{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-03-07T02:43:07.036122Z",
     "start_time": "2025-03-07T02:43:03.174360Z"
    }
   },
   "source": [
     "import matplotlib as mpl\n",
     "import matplotlib.pyplot as plt\n",
     "%matplotlib inline\n",
     "import numpy as np\n",
     "import sklearn\n",
     "import pandas as pd\n",
     "import os\n",
     "import sys\n",
     "import time\n",
     "from tqdm.auto import tqdm\n",
     "import torch\n",
     "import torch.nn as nn\n",
     "import torch.nn.functional as F\n",
     "\n",
     "# Report interpreter and library versions for reproducibility.\n",
     "print(sys.version_info)\n",
     "for module in mpl, np, pd, sklearn, torch:\n",
     "    print(module.__name__, module.__version__)\n",
     "\n",
     "# Prefer the first CUDA device when available, otherwise fall back to CPU.\n",
     "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
     "print(device)\n",
     "\n",
     "# Seed torch (CPU and all GPUs) and numpy for repeatable runs.\n",
     "seed = 42\n",
     "torch.manual_seed(seed)\n",
     "torch.cuda.manual_seed_all(seed)\n",
     "np.random.seed(seed)\n"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sys.version_info(major=3, minor=12, micro=3, releaselevel='final', serial=0)\n",
      "matplotlib 3.10.0\n",
      "numpy 1.26.4\n",
      "pandas 2.2.3\n",
      "sklearn 1.6.0\n",
      "torch 2.3.1+cu121\n",
      "cuda:0\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 数据加载",
   "id": "e370168f5260bd08"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T03:52:56.024328Z",
     "start_time": "2025-03-07T03:52:56.018463Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 导入unicodedata模块，用于处理Unicode字符\n",
    "import unicodedata\n",
    "# 导入re模块，用于正则表达式操作\n",
    "import re\n",
    "# 导入train_test_split函数，用于数据集划分\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "# 西班牙语中有一些特殊符号，所以我们需要unicode转ascii，这样值变小了，因为unicode太大\n",
    "def unicode_to_ascii(s):\n",
    "    \"\"\"\n",
    "    将Unicode字符串转换为ASCII字符串\n",
    "    :param s: Unicode字符串\n",
    "    :return: ASCII字符串\n",
    "    \"\"\"\n",
    "\n",
    "    # 使用unicodedata.normalize('NFD', s)将字符串s规范化，分解为基本字符和组合字符\n",
    "    # 然后过滤掉所有组合字符（Unicode类别为'Mn'的字符），并将剩余的字符拼接成新的字符串\n",
    "    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')\n",
    "\n",
    "\n",
    "def preprocess_sentence(w):\n",
    "    # 将句子转换为小写并去除首尾空格\n",
    "    w = unicode_to_ascii(w.lower().strip())\n",
    "\n",
    "    # 在标点符号前后添加空格，以便后续处理\n",
    "    w = re.sub(r\"([?.!,¿])\", r\" \\1 \", w)\n",
    "\n",
    "    # 将多个连续空格替换为单个空格\n",
    "    w = re.sub(r'[\" \"]+', \" \", w)\n",
    "\n",
    "    # 去除所有非字母和标点符号的字符，并用空格替换\n",
    "    w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
    "\n",
    "    # 去除句子末尾的空格并再次去除首尾空格\n",
    "    w = w.rstrip().strip()\n",
    "\n",
    "    return w\n",
    "\n",
    "\n",
    "en_sentence = \"May I borrow this book?\"  # 定义一个英文句子\n",
    "sp_sentence = \"¿Puedo tomar prestado este libro?\"  # 定义一个西班牙语句子\n",
    "\n",
    "print(unicode_to_ascii(en_sentence))  # 打印转换后的英文句子\n",
    "print(unicode_to_ascii(sp_sentence))  # 打印转换后的西班牙语句子\n",
    "\n",
    "print(preprocess_sentence(en_sentence))\n",
    "print(preprocess_sentence(sp_sentence))\n",
    "print(preprocess_sentence(sp_sentence).encode('utf-8'))  #¿是占用两个字节的"
   ],
   "id": "29ff7a62513abfea",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "May I borrow this book?\n",
      "¿Puedo tomar prestado este libro?\n",
      "may i borrow this book ?\n",
      "¿ puedo tomar prestado este libro ?\n",
      "b'\\xc2\\xbf puedo tomar prestado este libro ?'\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Datasets\n",
   "id": "240cbbfc5cd7ae10"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T04:40:31.831720Z",
     "start_time": "2025-03-07T04:40:29.713095Z"
    }
   },
   "cell_type": "code",
   "source": [
     "from pathlib import Path\n",
     "from torch.utils.data import Dataset, DataLoader\n",
     "\n",
     "\n",
     "class LangPairDataset(Dataset):\n",
     "    fpath = Path(r\"./data_spa_en/spa.txt\")  # raw data file path\n",
     "    cache_path = Path(r\"./.cache/lang_pair.npy\")  # preprocessed-data cache path\n",
     "\n",
     "    # Random 9:1 train/test split, drawn once at class-definition time so the\n",
     "    # \"train\" and \"test\" instances share one consistent assignment.\n",
     "    # NOTE(review): size=118964 is hard-coded to the corpus line count; it\n",
     "    # must equal len(lines) below or the boolean indexing misaligns -- confirm\n",
     "    # whenever the data file changes.\n",
     "    split_index = np.random.choice(a=['train', 'test'], replace=True, p=[0.9, 0.1], size=118964)\n",
     "\n",
     "    def __init__(self, mode=\"train\", cache=False):\n",
     "        # NOTE(review): cache=True forces a rebuild (despite the name); the\n",
     "        # cached file is reused only when cache=False and the file exists.\n",
     "        if cache or not self.cache_path.exists():  # (re)build the preprocessed data\n",
     "\n",
     "            # Ensure the cache directory exists: parents=True creates missing\n",
     "            # ancestors, exist_ok=True tolerates an existing directory.\n",
     "            self.cache_path.parent.mkdir(parents=True, exist_ok=True)\n",
     "            with open(self.fpath, \"r\", encoding=\"utf8\") as file:\n",
     "                # Read the whole corpus; each line is \"target<TAB>source\".\n",
     "                lines = file.readlines()\n",
     "\n",
     "                # Split each line on the tab and preprocess every field,\n",
     "                # producing a nested list of [trg, src] pairs.\n",
     "                lang_pair = [[preprocess_sentence(w) for w in l.split('\\t')] for l in\n",
     "                             lines]  # list[[trg,src],[trg,src],...]\n",
     "\n",
     "                # zip(*...) transposes the pair list into two tuples: all\n",
     "                # target sentences and all source sentences.\n",
     "                trg, src = zip(*lang_pair)\n",
     "                trg = np.array(trg)  # to numpy array\n",
     "                src = np.array(src)  # to numpy array\n",
     "                np.save(self.cache_path, {\"trg\": trg, \"src\": src})  # cache for future runs\n",
     "        else:\n",
     "            # allow_pickle=True because the file stores a Python dict;\n",
     "            # .item() unwraps the 0-d object array back into that dict.\n",
     "            lang_pair = np.load(self.cache_path, allow_pickle=True).item()\n",
     "            trg = lang_pair[\"trg\"]\n",
     "            src = lang_pair[\"src\"]\n",
     "\n",
     "        self.trg = trg[self.split_index == mode]  # target-language rows (English)\n",
     "        self.src = src[self.split_index == mode]  # source-language rows (Spanish)\n",
     "\n",
     "    def __getitem__(self, index):\n",
     "        # Return one (source, target) sentence pair.\n",
     "        return self.src[index], self.trg[index]\n",
     "\n",
     "    def __len__(self):\n",
     "        return len(self.src)\n",
     "\n",
     "\n",
     "train_ds = LangPairDataset(\"train\")\n",
     "test_ds = LangPairDataset(\"test\")"
   ],
   "id": "988963b3f6890ab6",
   "outputs": [],
   "execution_count": 5
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T04:40:32.586376Z",
     "start_time": "2025-03-07T04:40:32.583318Z"
    }
   },
   "cell_type": "code",
   "source": [
     "# zip(*nested) example: transposes a list of rows into tuples of columns.\n",
     "a = [[1, 2], [4, 5], [7, 8]]\n",
     "zipped = list(zip(*a))\n",
     "print(zipped)"
   ],
   "id": "a5871be204fb2d28",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 4, 7), (2, 5, 8)]\n"
     ]
    }
   ],
   "execution_count": 6
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T04:40:56.603796Z",
     "start_time": "2025-03-07T04:40:56.600441Z"
    }
   },
   "cell_type": "code",
    "source": "# Show one (source, target) pair from the training split.\nprint(\"source: {}\\ntarget: {}\".format(*train_ds[-1]))",
   "id": "2f60d4dbc35b15f5",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "source: si quieres sonar como un hablante nativo , debes estar dispuesto a practicar diciendo la misma frase una y otra vez de la misma manera en que un musico de banjo practica el mismo fraseo una y otra vez hasta que lo puedan tocar correctamente y en el tiempo esperado .\n",
      "target: if you want to sound like a native speaker , you must be willing to practice saying the same sentence over and over in the same way that banjo players practice the same phrase over and over until they can play it correctly and at the desired tempo .\n"
     ]
    }
   ],
   "execution_count": 7
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# Tokenizer\n",
    "\n",
    "这里有两种处理方式，分别对应着 encoder 和 decoder 的 word embedding 是否共享，这里实现不共享的方案。"
   ],
   "id": "6991a05c17a53348"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T04:56:19.043743Z",
     "start_time": "2025-03-07T04:56:18.674555Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Counter 是一个用于计数的工具，它可以方便地对可迭代对象（如列表、字符串等）中的元素进行计数，并返回一个字典形式的计数结果，其中键是元素，值是该元素出现的次数。\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def get_word_idx(ds, mode=\"src\", threshold=2):\n",
    "    word2idx = {\n",
    "        \"[PAD]\": 0,  # 填充 token\n",
    "        \"[BOS]\": 1,  # begin of sentence\n",
    "        \"[UNK]\": 2,  # 未知 token\n",
    "        \"[EOS]\": 3,  # end of sentence\n",
    "    }\n",
    "    idx2word = {value: key for key, value in word2idx.items()}\n",
    "    index = len(idx2word)\n",
    "    threshold = 1\n",
    "\n",
    "    # 如果数据集有很多个G，那是用for循环的，不能' '.join\n",
    "    # 根据 mode 选择源语言或目标语言的句子，将所有句子连接成一个字符串，然后拆分成单词列表\n",
    "    # 如果 mode 为 \"src\"，则选择 pair 中的第一个元素（索引为 0）；否则选择第二个元素（索引为 1）\n",
    "    # \" \".join(...):将上一步生成的列表中的所有句子用空格连接成一个长字符串。\n",
    "    # .split():将连接后的长字符串按空格分割，生成一个单词列表 word_list。\n",
    "    word_list = \" \".join([pair[0 if mode == \"src\" else 1] for pair in ds]).split()\n",
    "\n",
    "    counter = Counter(word_list)  #统计词频,counter类似字典，key是单词，value是出现次数\n",
    "    print(\"word count:\", len(counter))\n",
    "\n",
    "    for token, count in counter.items():\n",
    "        if count >= threshold:  #出现次数大于阈值的token加入词表\n",
    "            word2idx[token] = index  #加入词表\n",
    "            idx2word[index] = token  #加入反向词表\n",
    "            index += 1\n",
    "\n",
    "    return word2idx, idx2word\n",
    "\n",
    "\n",
    "src_word2idx, src_idx2word = get_word_idx(train_ds, \"src\")  #源语言词表  西班牙语\n",
    "trg_word2idx, trg_idx2word = get_word_idx(train_ds, \"trg\")  #目标语言词表 英语"
   ],
   "id": "929597786f5f1f37",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "word count: 23774\n",
      "word count: 12465\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:21:40.750622Z",
     "start_time": "2025-03-07T06:21:40.745168Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class Tokenizer:\n",
    "    def __init__(self, word2idx, idx2word, max_length=500, pad_idx=0, bos_idx=1, eos_idx=3, unk_idx=2):\n",
    "        self.word2idx = word2idx\n",
    "        self.idx2word = idx2word\n",
    "        self.max_length = max_length\n",
    "        self.pad_idx = pad_idx\n",
    "        self.bos_idx = bos_idx\n",
    "        self.eos_idx = eos_idx\n",
    "        self.unk_idx = unk_idx\n",
    "\n",
    "    def encode(self, text_list, padding_first=False, add_bos=True, add_eos=True, return_mask=False):\n",
    "        \"\"\"\n",
    "        param text_list: 一个列表，包含多个句子，每个句子是一个列表，包含多个单词\n",
    "        param padding_first: 是否padding加载前面\n",
    "        param add_bos: 是否添加bos\n",
    "        param add_eos: 是否添加eos\n",
    "        param return_mask: 是否返回mask(掩码），mask用于指示哪些是padding的，哪些是真实的token\n",
    "        return: input_ids, masks\n",
    "        \"\"\"\n",
    "        # 计算最大长度\n",
    "        max_length = min(self.max_length, add_eos + add_bos + max([len(text) for text in text_list]))\n",
    "        indices_list = []\n",
    "\n",
    "        # 遍历每个句子\n",
    "        for text in text_list:\n",
    "\n",
    "            # 遍历每个单词，如果词表中有这个词，就用词表中的index，否则用unk_idx代替\n",
    "            indices = [self.word2idx.get(word, self.unk_idx) for word in text[:max_length - add_bos - add_eos]]\n",
    "\n",
    "            if add_bos:\n",
    "                indices = [self.bos_idx] + indices\n",
    "            if add_eos:\n",
    "                indices = indices + [self.eos_idx]\n",
    "            if padding_first:  #padding加载前面\n",
    "                indices = [self.pad_idx] * (max_length - len(indices)) + indices\n",
    "            else:  #padding加载后面\n",
    "                indices = indices + [self.pad_idx] * (max_length - len(indices))\n",
    "            indices_list.append(indices)\n",
    "            input_ids = torch.tensor(indices_list)\n",
    "\n",
    "        # mask是一个和input_ids一样大小的tensor，0代表token，1代表padding，mask用于去除padding的影响\n",
    "        masks = (input_ids == self.pad_idx).to(dtype=torch.int64)\n",
    "\n",
    "        # 返回input_ids和mask,如果return_mask为False，则只返回input_ids\n",
    "        return input_ids if not return_mask else (input_ids, masks)\n",
    "\n",
    "    def decode(self, indices_list, remove_bos=True, remove_eos=True, remove_pad=True, split=False):\n",
    "        text_list = []\n",
    "        for indices in indices_list:\n",
    "            text = []\n",
    "            for index in indices:\n",
    "                # 如果词表中有这个词，就用词表中的词，否则用unk_idx代替\n",
    "                word = self.idx2word.get(index, \"[UNK]\")\n",
    "                if remove_bos and word == \"[BOS]\":\n",
    "                    continue\n",
    "                if remove_eos and word == \"[EOS]\":  #如果到达eos，就结束\n",
    "                    break\n",
    "                if remove_pad and word == \"[PAD]\":  #如果到达pad，就结束\n",
    "                    break\n",
    "                text.append(word)  #单词添加到列表中\n",
    "            # 把列表中的单词拼接，变为一个句子\n",
    "            # 如果 split 为 False，则将 text 列表中的所有单词用空格连接成一个字符串，然后将这个字符串添加到 text_list 中。\n",
    "            # 如果 split 为 True，则直接将 text 列表添加到 text_list 中，而不会将其转换为字符串。\n",
    "            text_list.append(\" \".join(text) if not split else text)\n",
    "        return text_list\n",
    "\n",
    "\n",
    "#使用两个分词器（tokenizer）的好处之一是可以减少嵌入层（embedding layer）的参数量。\n",
    "# 通常在机器翻译任务中，源语言（src）和目标语言（trg）的词汇表是不同的，因此需要分别为每种语言创建一个独立的嵌入层。\n",
    "# 这样做的好处是每个语言的嵌入层只包含该语言的词汇，避免了不必要的参数，从而减少了模型的复杂性和训练资源的消耗。\n",
    "src_tokenizer = Tokenizer(word2idx=src_word2idx, idx2word=src_idx2word)  #源语言tokenizer\n",
    "trg_tokenizer = Tokenizer(word2idx=trg_word2idx, idx2word=trg_idx2word)  #目标语言tokenizer"
   ],
   "id": "b7f6380b0a0f7397",
   "outputs": [],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:22:07.582891Z",
     "start_time": "2025-03-07T06:22:07.558456Z"
    }
   },
   "cell_type": "code",
   "source": [
     "# Round-trip check: encode a small batch, then decode it back with all\n",
     "# special tokens kept, so the padding/BOS/EOS placement is visible.\n",
     "raw_text = [\"hello world\".split(), \"tokenize text datas with batch\".split(), \"this is a test\".split()]\n",
     "indices, mask = trg_tokenizer.encode(raw_text, padding_first=False, add_bos=True, add_eos=True, return_mask=True)\n",
     "decode_text = trg_tokenizer.decode(indices.tolist(), remove_bos=False, remove_eos=False, remove_pad=False)\n",
     "print(\"raw text\" + '-' * 10)\n",
     "for raw in raw_text:\n",
     "    print(raw)\n",
     "print(\"mask\" + '-' * 10)\n",
     "for m in mask:\n",
     "    print(m)\n",
     "print(\"indices\" + '-' * 10)\n",
     "for index in indices:\n",
     "    print(index)\n",
     "print(\"decode text\" + '-' * 10)\n",
     "for decode in decode_text:\n",
     "    print(decode)"
   ],
   "id": "ac3891480fb2b51f",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "raw text----------\n",
      "['hello', 'world']\n",
      "['tokenize', 'text', 'datas', 'with', 'batch']\n",
      "['this', 'is', 'a', 'test']\n",
      "mask----------\n",
      "tensor([0, 0, 0, 0, 1, 1, 1])\n",
      "tensor([0, 0, 0, 0, 0, 0, 0])\n",
      "tensor([0, 0, 0, 0, 0, 0, 1])\n",
      "indices----------\n",
      "tensor([   1,   17, 3224,    3,    0,    0,    0])\n",
      "tensor([   1,    2, 3870,    2,  538,    2,    3])\n",
      "tensor([   1,  121,  233,  107, 1262,    3,    0])\n",
      "decode text----------\n",
      "[BOS] hello world [EOS] [PAD] [PAD] [PAD]\n",
      "[BOS] [UNK] text [UNK] with [UNK] [EOS]\n",
      "[BOS] this is a test [EOS] [PAD]\n"
     ]
    }
   ],
   "execution_count": 10
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# DataLoader",
   "id": "c215887cdf11764a"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:28:17.191151Z",
     "start_time": "2025-03-07T06:28:17.186675Z"
    }
   },
   "cell_type": "code",
   "source": [
     "def collate_fct(batch):\n",
     "    \"\"\"Collate (src_sentence, trg_sentence) string pairs into model tensors.\"\"\"\n",
     "    # Split each source sentence into a word list.\n",
     "    src_words = [pair[0].split() for pair in batch]\n",
     "    # Split each target sentence into a word list.\n",
     "    trg_words = [pair[1].split() for pair in batch]\n",
     "\n",
     "    # Encoder inputs: left-padded, wrapped in [BOS]...[EOS]; the mask flags\n",
     "    # the padding positions.\n",
     "    encoder_inputs, encoder_inputs_mask = src_tokenizer.encode(\n",
     "        src_words, padding_first=True, add_bos=True, add_eos=True, return_mask=True\n",
     "    )\n",
     "\n",
     "    # Decoder inputs: [BOS] + sentence (no [EOS]), right-padded, no mask.\n",
     "    decoder_inputs = trg_tokenizer.encode(\n",
     "        trg_words, padding_first=False, add_bos=True, add_eos=False, return_mask=False,\n",
     "    )\n",
     "\n",
     "    # Decoder labels: sentence + [EOS] (no [BOS]) -- the decoder inputs\n",
     "    # shifted by one step -- plus their padding mask.\n",
     "    decoder_labels, decoder_labels_mask = trg_tokenizer.encode(\n",
     "        trg_words, padding_first=False, add_bos=False, add_eos=True, return_mask=True\n",
     "    )\n",
     "\n",
     "    # Move everything to the target device before returning.\n",
     "    return {\n",
     "        \"encoder_inputs\": encoder_inputs.to(device=device),\n",
     "        \"encoder_inputs_mask\": encoder_inputs_mask.to(device=device),\n",
     "        \"decoder_inputs\": decoder_inputs.to(device=device),\n",
     "        \"decoder_labels\": decoder_labels.to(device=device),\n",
     "        \"decoder_labels_mask\": decoder_labels_mask.to(device=device),\n",
     "    }\n",
     "\n",
     "# When many values are returned, a dict keeps the call site readable."
   ],
   "id": "77bdabe3b205edf9",
   "outputs": [],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:28:50.717135Z",
     "start_time": "2025-03-07T06:28:50.593093Z"
    }
   },
   "cell_type": "code",
   "source": [
     "# Smoke-test the collate function: build a tiny DataLoader, print one batch.\n",
     "sample_dl = DataLoader(train_ds, batch_size=2, shuffle=True, collate_fn=collate_fct)\n",
     "\n",
     "for batch in sample_dl:\n",
     "    for key, value in batch.items():\n",
     "        print(key)\n",
     "        print(value)\n",
     "    break"
   ],
   "id": "591a7ac74391eb7",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder_inputs\n",
      "tensor([[   0,    0,    0,    0,    1,  350, 4002, 2826, 2827,    5,    3],\n",
      "        [   1,   12, 2266,  706,   80,  294,   88,   83,  297,   14,    3]],\n",
      "       device='cuda:0')\n",
      "encoder_inputs_mask\n",
      "tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')\n",
      "decoder_inputs\n",
      "tensor([[   1,  197,  756, 1413,    5,    0,    0,    0,    0,    0],\n",
      "        [   1,  332,   90, 1090,  443,  158,   31,  680, 3688,   10]],\n",
      "       device='cuda:0')\n",
      "decoder_labels\n",
      "tensor([[ 197,  756, 1413,    5,    3,    0,    0,    0,    0,    0],\n",
      "        [ 332,   90, 1090,  443,  158,   31,  680, 3688,   10,    3]],\n",
      "       device='cuda:0')\n",
      "decoder_labels_mask\n",
      "tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')\n"
     ]
    }
   ],
   "execution_count": 12
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 定义模型",
   "id": "903a2b56d1bfdff9"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:33:39.011980Z",
     "start_time": "2025-03-07T06:33:39.008014Z"
    }
   },
   "cell_type": "code",
   "source": [
     "class Encoder(nn.Module):\n",
     "    \"\"\"Embedding + GRU encoder for the source sentence.\"\"\"\n",
     "\n",
     "    def __init__(\n",
     "            self,\n",
     "            vocab_size,\n",
     "            embedding_dim=256,\n",
     "            hidden_dim=1024,\n",
     "            num_layers=1,\n",
     "    ):\n",
     "        super().__init__()\n",
     "        # Maps token ids to embedding_dim vectors.\n",
     "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
     "        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)\n",
     "\n",
     "    def forward(self, encoder_inputs):\n",
     "        # encoder_inputs.shape = [batch size, sequence length]\n",
     "        embeds = self.embedding(encoder_inputs)\n",
     "        # embeds.shape = [batch size, sequence length, embedding_dim]\n",
     "        seq_output, hidden = self.gru(embeds)\n",
     "        # seq_output: [batch size, sequence length, hidden_dim] (per-step outputs)\n",
     "        # hidden: [num_layers, batch size, hidden_dim] (final state per layer)\n",
     "        return seq_output, hidden"
   ],
   "id": "28bf4825e40b77f3",
   "outputs": [],
   "execution_count": 13
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:36:57.758480Z",
     "start_time": "2025-03-07T06:36:57.549596Z"
    }
   },
   "cell_type": "code",
   "source": [
     "# Shape check for the Encoder with a toy configuration.\n",
     "encoder = Encoder(vocab_size=100, embedding_dim=256, hidden_dim=1024, num_layers=4)\n",
     "# A batch of 2 sequences, each 50 random token ids.\n",
     "encoder_inputs = torch.randint(0, 100, (2, 50))\n",
     "encoder_outputs, hidden = encoder(encoder_inputs)\n",
     "print(f'encoder_outputs.shape：{encoder_outputs.shape}')\n",
     "print(f'hidden.shape：{hidden.shape}')"
   ],
   "id": "33781f79cdfee08b",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder_outputs.shape：torch.Size([2, 50, 1024])\n",
      "hidden.shape：torch.Size([4, 2, 1024])\n"
     ]
    }
   ],
   "execution_count": 19
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:37:20.811749Z",
     "start_time": "2025-03-07T06:37:20.805230Z"
    }
   },
   "cell_type": "code",
   "source": [
     "query1 = torch.randn(2, 1024)\n",
     "query1.unsqueeze(1).shape  # unsqueeze inserts a new axis at dim 1"
   ],
   "id": "410c4059c256bac5",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 1024])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 20
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
     "- EO: encoder 各个位置的输出\n",
     "- H: decoder 某一步的隐含状态\n",
     "- FC: 全连接层\n",
     "- X: decoder 的一个输入\n",
     "\n",
     "score = FC(tanh(FC(EO) + FC(H)))"
   ],
   "id": "7af8b21d8805f7f4"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:47:26.118241Z",
     "start_time": "2025-03-07T06:47:26.112787Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class BahdanauAttention(nn.Module):\n",
    "    def __init__(self, hidden_dim=1024):\n",
    "        super().__init__()\n",
    "        self.Wk = nn.Linear(hidden_dim, hidden_dim)  # FC对keys做运算，encoder的输出EO\n",
    "        self.Wq = nn.Linear(hidden_dim, hidden_dim)  # FC对query做运算，decoder的隐藏状态\n",
    "        self.V = nn.Linear(hidden_dim, 1)  # FC对tanh(Wk+Wq)做运算，得到注意力分数\n",
    "\n",
    "    def forward(self, query, keys, values, attn_mask=None):\n",
    "        \"\"\"\n",
    "        正向传播\n",
    "        param query: H，是decoder的隐藏状态，shape = [batch size, hidden_dim]\n",
    "        param keys: EO  [batch size, sequence length, hidden_dim]\n",
    "        param values: EO  [batch size, sequence length, hidden_dim]\n",
    "        param attn_mask:[batch size, sequence length]\n",
    "        return:\n",
    "        \"\"\"\n",
    "        # Bahdanau注意力机制能够为每个key生成一个注意力分数，从而使得模型能够根据这些分数来加权组合values，以生成最终的上下文向量。这个上下文向量将用于帮助模型生成更准确的输出\n",
    "        # query.shape = [batch size, hidden_dim] -->通过unsqueeze(-2)增加维度 [batch size, 1, hidden_dim]\n",
    "        # self.Wq(query.unsqueeze(-2)).shape = [batch size, 1, hidden_dim]\n",
    "        # self.Wk(keys).shape = [batch size, sequence length, hidden_dim]\n",
    "        # self.Wk(keys) + self.Wq(query.unsqueeze(-2)).shape = [batch size, sequence length, hidden_dim]\n",
    "        # F.tanh(self.Wk(keys) + self.Wq(query.unsqueeze(-2))).shape = [batch size, sequence length, hidden_dim]\n",
    "        # self.V(F.tanh(self.Wk(keys) + self.Wq(query.unsqueeze(-2)))).shape = [batch size, sequence length, 1]\n",
    "        scores = self.V(F.tanh(self.Wk(keys) + self.Wq(query.unsqueeze(-2))))\n",
    "\n",
    "        if attn_mask is not None:\n",
    "            # 如果提供了注意力掩码，则对其进行处理以适应注意力分数的形状\n",
    "            # attn_mask.shape = [batch size, sequence length] --> [batch size, sequence length, 1]\n",
    "            # 将掩码值乘以 -1e16 以在后续的 softmax 操作中忽略这些位置\n",
    "            attn_mask = (attn_mask.unsqueeze(-1)) * -1e16\n",
    "            # 将掩码应用到注意力分数上\n",
    "            # scores.shape = [batch size, sequence length, 1]\n",
    "            # attn_mask.shape = [batch size, sequence length, 1]\n",
    "            # scores + attn_mask 会使得掩码位置的分数非常小，从而在 softmax 后接近于0，忽略这些位置的影响\n",
    "            scores = scores + attn_mask\n",
    "\n",
    "        # 对注意力分数进行 softmax 操作，得到注意力权重\n",
    "        # 这里的 dim=-2 表示在 sequence length 维度上进行 softmax，得到每个位置的重要性权重\n",
    "        # scores.shape = [batch size, sequence length, 1] --> [batch size, 1, sequence length]\n",
    "        scores = F.softmax(scores, dim=-2)\n",
    "\n",
    "        # 计算上下文向量，即加权的编码器输出\n",
    "        # context_vector.shape = [batch size, hidden_dim]\n",
    "        # torch.mul(scores, values) 对注意力权重scores和编码器输出的 values 进行逐元素乘法\n",
    "        # .sum(dim=-2) 对 sequence length 维度求和，得到加权后的结果\n",
    "        context_vector = torch.mul(scores, values).sum(dim=-2)\n",
    "\n",
    "        # 返回上下文向量和注意力权重\n",
    "        return context_vector, scores"
   ],
   "id": "15cb3fcfde35e5c3",
   "outputs": [],
   "execution_count": 21
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-07T06:47:56.175764Z",
     "start_time": "2025-03-07T06:47:56.128512Z"
    }
   },
   "cell_type": "code",
   "source": [
     "# Shape check for BahdanauAttention with toy tensors.\n",
     "attention = BahdanauAttention(hidden_dim=1024)\n",
     "query = torch.randn(2, 1024)  # decoder hidden state\n",
     "keys = torch.randn(2, 50, 1024)  # encoder outputs (EO)\n",
     "values = torch.randn(2, 50, 1024)  # encoder outputs (EO)\n",
     "attn_mask = torch.randint(0, 2, (2, 50))\n",
     "context_vector, scores = attention(query, keys, values, attn_mask)\n",
     "print(context_vector.shape)\n",
     "print(scores.shape)"
   ],
   "id": "cc6f7811492252d2",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 1024])\n",
      "torch.Size([2, 50, 1])\n"
     ]
    }
   ],
   "execution_count": 22
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-08T00:39:52.643234Z",
     "start_time": "2025-03-08T00:39:52.637642Z"
    }
   },
   "cell_type": "code",
   "source": [
     "class Decoder(nn.Module):\n",
     "    \"\"\"GRU decoder with Bahdanau attention over the encoder outputs.\"\"\"\n",
     "\n",
     "    def __init__(\n",
     "            self,\n",
     "            vocab_size,\n",
     "            embedding_dim=256,\n",
     "            hidden_dim=1024,\n",
     "            num_layers=1,\n",
     "    ):\n",
     "        super(Decoder, self).__init__()\n",
     "        # Embeds each target-vocabulary id into an embedding_dim vector.\n",
     "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
     "        # The GRU consumes [context vector ; embedding], hence the summed input size.\n",
     "        self.gru = nn.GRU(embedding_dim + hidden_dim, hidden_dim, num_layers=num_layers, batch_first=True)\n",
     "        # Projects the GRU output to one score per vocabulary word.\n",
     "        self.fc = nn.Linear(hidden_dim, vocab_size)\n",
     "        # Dropout before the output projection to reduce overfitting.\n",
     "        self.dropout = nn.Dropout(0.6)\n",
     "        # Additive attention over the encoder outputs.\n",
     "        self.attention = BahdanauAttention(hidden_dim)\n",
     "\n",
     "    def forward(self, decoder_input, hidden, encoder_outputs, attn_mask=None):\n",
     "        \"\"\"One decoding step.\n",
     "\n",
     "        :param decoder_input: [batch size, 1] current target token ids\n",
     "        :param hidden: [batch size, hidden_dim] attention query (the caller\n",
     "            passes the last layer of the previous step's hidden state)\n",
     "        :param encoder_outputs: [batch size, sequence length, hidden_dim]\n",
     "        :param attn_mask: [batch size, sequence length], 1 marks padding\n",
     "        :return: (logits, new hidden state, attention scores)\n",
     "        \"\"\"\n",
     "        # A single time step: decoder_input must be [batch size, 1].\n",
     "        assert len(decoder_input.shape) == 2 and decoder_input.shape[\n",
     "            -1] == 1, f\"decoder_input.shape = {decoder_input.shape}\"\n",
     "        # hidden is 2-D here: [batch size, hidden_dim].\n",
     "        assert len(hidden.shape) == 2, f\"hidden.shape = {hidden.shape}\"\n",
     "        # encoder_outputs: [batch size, sequence length, hidden_dim].\n",
     "        assert len(encoder_outputs.shape) == 3, f\"encoder_outputs.shape = {encoder_outputs.shape}\"\n",
     "\n",
     "        # context_vector: [batch size, hidden_dim]\n",
     "        # attention_score: [batch size, sequence length, 1]\n",
     "        context_vector, attention_score = self.attention(\n",
     "            query=hidden, keys=encoder_outputs, values=encoder_outputs, attn_mask=attn_mask)\n",
     "\n",
     "        # embeds: [batch size, 1, embedding_dim]\n",
     "        embeds = self.embedding(decoder_input)\n",
     "        # Concatenate context and embedding: [batch size, 1, embedding_dim + hidden_dim]\n",
     "        embeds = torch.cat([context_vector.unsqueeze(-2), embeds], dim=-1)\n",
     "        # NOTE(review): the incoming `hidden` is used only as the attention\n",
     "        # query and is NOT passed to the GRU, so every step's GRU starts from\n",
     "        # a zero initial state -- confirm whether self.gru(embeds, ...) should\n",
     "        # receive the previous hidden state.\n",
     "        # seq_output: [batch size, 1, hidden_dim]; hidden: [num_layers, batch size, hidden_dim]\n",
     "        seq_output, hidden = self.gru(embeds)\n",
     "        # logits: [batch size, 1, vocab_size]\n",
     "        logits = self.fc(self.dropout(seq_output))\n",
     "        return logits, hidden, attention_score\n"
   ],
   "id": "5b31b40cb4eba779",
   "outputs": [],
   "execution_count": 24
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T01:31:20.879415Z",
     "start_time": "2025-03-10T01:31:20.871892Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class Sequence2Sequence(nn.Module):\n",
    "    \"\"\"Encoder-decoder sequence-to-sequence model with attention.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "            self,\n",
    "            src_vocab_size,  # source-side vocabulary size\n",
    "            trg_vocab_size,  # target-side vocabulary size\n",
    "            encoder_embedding_dim=256,\n",
    "            encoder_hidden_dim=1024,\n",
    "            encoder_num_layers=4,\n",
    "            decoder_embedding_dim=256,\n",
    "            decoder_hidden_dim=1024,\n",
    "            decoder_num_layers=4,\n",
    "            bos_idx=1,\n",
    "            eos_idx=3,\n",
    "            max_length=512,\n",
    "    ):\n",
    "        super(Sequence2Sequence, self).__init__()\n",
    "        self.bos_idx = bos_idx  # id of the begin-of-sequence token\n",
    "        self.eos_idx = eos_idx  # id of the end-of-sequence token\n",
    "        self.max_length = max_length  # cap on generated length in infer()\n",
    "        self.encoder = Encoder(\n",
    "            src_vocab_size,\n",
    "            embedding_dim=encoder_embedding_dim,\n",
    "            hidden_dim=encoder_hidden_dim,\n",
    "            num_layers=encoder_num_layers,\n",
    "        )\n",
    "        self.decoder = Decoder(\n",
    "            trg_vocab_size,\n",
    "            embedding_dim=decoder_embedding_dim,\n",
    "            hidden_dim=decoder_hidden_dim,\n",
    "            num_layers=decoder_num_layers,\n",
    "        )\n",
    "\n",
    "    def forward(self, *, encoder_inputs, decoder_inputs, attn_mask=None):\n",
    "        \"\"\"\n",
    "        Teacher-forced forward pass.\n",
    "\n",
    "        Args:\n",
    "        - encoder_inputs: source token ids fed to the encoder.\n",
    "        - decoder_inputs: target token ids, shape [batch size, trg seq len].\n",
    "        - attn_mask: optional attention mask over the encoder positions.\n",
    "\n",
    "        Returns:\n",
    "        - logits of every decoder step, concatenated on the time axis, and\n",
    "          the attention scores of every step, concatenated on the last axis.\n",
    "        \"\"\"\n",
    "        # encode the source sequence\n",
    "        encoder_outputs, hidden = self.encoder(encoder_inputs)\n",
    "\n",
    "        # batch size and target sequence length\n",
    "        bs, seq_len = decoder_inputs.shape\n",
    "\n",
    "        # per-step logits and attention scores\n",
    "        logits_list = []\n",
    "        scores_list = []\n",
    "\n",
    "        # run the decoder one time step at a time (teacher forcing: the\n",
    "        # ground-truth token of step i is the decoder input of step i)\n",
    "        for i in range(seq_len):\n",
    "            logits, hidden, score = self.decoder(\n",
    "                decoder_inputs[:, i:i + 1],  # token of the current step\n",
    "                hidden[-1],  # last layer of the previous hidden state\n",
    "                encoder_outputs,  # encoder outputs for attention\n",
    "                attn_mask=attn_mask  # attention mask\n",
    "            )\n",
    "\n",
    "            # collect this step's logits and attention score\n",
    "            logits_list.append(logits)\n",
    "            scores_list.append(score)\n",
    "\n",
    "        # concatenate over the time dimension / the score dimension\n",
    "        return torch.cat(logits_list, dim=-2), torch.cat(scores_list, dim=-1)\n",
    "\n",
    "    @torch.no_grad()  # inference only; no gradients are needed\n",
    "    def infer(self, encoder_input, attn_mask=None):\n",
    "        \"\"\"\n",
    "        Greedy decoding for a single source sequence (batch size 1).\n",
    "\n",
    "        Args:\n",
    "        - encoder_input: source token ids fed to the encoder.\n",
    "        - attn_mask: optional attention mask over the encoder positions.\n",
    "\n",
    "        Returns:\n",
    "        - pred_list: generated token ids as a Python list.\n",
    "        - attention scores of every step, concatenated on the last axis.\n",
    "        \"\"\"\n",
    "\n",
    "        # encode the source sequence\n",
    "        encoder_outputs, hidden = self.encoder(encoder_input)\n",
    "\n",
    "        # start decoding from the begin-of-sequence token; create it on the\n",
    "        # same device as the encoder outputs (the previous CPU-only tensor\n",
    "        # broke inference when the model ran on GPU)\n",
    "        decoder_input = torch.full((1, 1), self.bos_idx, dtype=torch.int64,\n",
    "                                   device=encoder_outputs.device)\n",
    "\n",
    "        # generated tokens and per-step attention scores\n",
    "        pred_list = []\n",
    "        score_list = []\n",
    "\n",
    "        # generate at most max_length tokens\n",
    "        for _ in range(self.max_length):\n",
    "            logits, hidden, score = self.decoder(\n",
    "                decoder_input,\n",
    "                hidden[-1],  # last layer of the previous hidden state\n",
    "                encoder_outputs,\n",
    "                attn_mask=attn_mask\n",
    "            )\n",
    "\n",
    "            # greedy choice: take the most probable next token\n",
    "            decoder_pred = logits.argmax(dim=-1)\n",
    "\n",
    "            # feed the prediction back in as the next decoder input\n",
    "            decoder_input = decoder_pred\n",
    "\n",
    "            # record the predicted token id and the attention score\n",
    "            pred_list.append(decoder_pred.reshape(-1).item())\n",
    "            score_list.append(score)\n",
    "\n",
    "            # stop as soon as the end-of-sequence token is produced\n",
    "            if decoder_pred == self.eos_idx:\n",
    "                break\n",
    "\n",
    "        # generated sequence plus the concatenated attention scores\n",
    "        return pred_list, torch.cat(score_list, dim=-1)\n",
    "\n"
   ],
   "id": "3e6ad0bc4655318f",
   "outputs": [],
   "execution_count": 48
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 训练",
   "id": "731624f964d3c321"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 定义损失函数",
   "id": "84d8bf923ae31230"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T00:41:34.132466Z",
     "start_time": "2025-03-10T00:41:34.126965Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def cross_entropy_with_padding(logits, labels, padding_mask=None):\n",
    "    \"\"\"\n",
    "    计算带有填充的交叉熵损失。\n",
    "    参数:\n",
    "    - logits: 预测的logits张量，形状为 [batch size, sequence length, num of classes]。\n",
    "    - labels: 真实标签张量，形状为 [batch size, sequence length]。\n",
    "    - padding_mask: 可选的填充掩码，形状为 [batch size, sequence length]。\n",
    "    返回:\n",
    "    - loss: 交叉熵损失。\n",
    "    \"\"\"\n",
    "    # 获取logits的形状：bs为batch size，seq_len为序列长度，nc为类别数\n",
    "    bs, seq_len, nc = logits.shape\n",
    "\n",
    "    # 使用交叉熵损失函数计算损失。logits被重塑为(bs * seq_len, nc)，labels被展平为(-1)。\n",
    "    # reduce=False表示不直接对损失进行求和或平均，返回每个样本的损失值\n",
    "    loss = F.cross_entropy(logits.reshape(bs * seq_len, nc), labels.reshape(-1),\n",
    "                           reduce=False)\n",
    "\n",
    "    # 如果padding_mask为None，即没有填充掩码，则直接对损失求平均\n",
    "    if padding_mask is None:\n",
    "        loss = loss.mean()\n",
    "    else:\n",
    "        # 如果提供了 padding_mask，则将填充部分的损失去除后计算有效损失的均值。首先，通过将 padding_mask reshape 成一维张量，并取 1 减去得到填充掩码。这样填充部分的掩码值变为 1，非填充部分变为 0。将损失张量与填充掩码相乘，这样填充部分的损失就会变为 0。然后，计算非填充部分的损失和（sum）以及非填充部分的掩码数量（sum）作为有效损失的均值计算。(因为上面我们设计的mask的token是0，所以这里是1-padding_mask)\n",
    "        # 如果有填充掩码，将padding_mask展平为(-1)\n",
    "        padding_mask = 1 - padding_mask.reshape(-1)\n",
    "\n",
    "        # 将损失与padding_mask相乘，以忽略填充部分的损失，然后求和并除以有效部分的总和，得到加权平均损失\n",
    "        loss = torch.mul(loss, padding_mask).sum() / padding_mask.sum()\n",
    "    return loss"
   ],
   "id": "58456d97e562cfd8",
   "outputs": [],
   "execution_count": 28
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T00:41:50.403840Z",
     "start_time": "2025-03-10T00:41:43.780337Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "class TensorBoardCallback:\n",
    "    \"\"\"Thin wrapper around SummaryWriter for logging training curves.\"\"\"\n",
    "\n",
    "    def __init__(self, log_dir, flush_secs=10):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            log_dir (str): directory the event files are written to.\n",
    "            flush_secs (int, optional): flush-to-disk interval in seconds. Defaults to 10.\n",
    "        \"\"\"\n",
    "        self.writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)\n",
    "\n",
    "    def draw_model(self, model, input_shape):\n",
    "        # trace the model once with a random tensor to render its graph\n",
    "        self.writer.add_graph(model, input_to_model=torch.randn(input_shape))\n",
    "\n",
    "    def add_loss_scalars(self, step, loss, val_loss):\n",
    "        scalars = {\"loss\": loss, \"val_loss\": val_loss}\n",
    "        self.writer.add_scalars(main_tag=\"training/loss\",\n",
    "                                tag_scalar_dict=scalars,\n",
    "                                global_step=step)\n",
    "\n",
    "    def add_acc_scalars(self, step, acc, val_acc):\n",
    "        scalars = {\"accuracy\": acc, \"val_accuracy\": val_acc}\n",
    "        self.writer.add_scalars(main_tag=\"training/accuracy\",\n",
    "                                tag_scalar_dict=scalars,\n",
    "                                global_step=step)\n",
    "\n",
    "    def add_lr_scalars(self, step, learning_rate):\n",
    "        scalars = {\"learning_rate\": learning_rate}\n",
    "        self.writer.add_scalars(main_tag=\"training/learning_rate\",\n",
    "                                tag_scalar_dict=scalars,\n",
    "                                global_step=step)\n",
    "\n",
    "    def __call__(self, step, **kwargs):\n",
    "        # each metric pair is logged only when both of its values were given\n",
    "        loss, val_loss = kwargs.pop(\"loss\", None), kwargs.pop(\"val_loss\", None)\n",
    "        if loss is not None and val_loss is not None:\n",
    "            self.add_loss_scalars(step, loss, val_loss)\n",
    "        acc, val_acc = kwargs.pop(\"acc\", None), kwargs.pop(\"val_acc\", None)\n",
    "        if acc is not None and val_acc is not None:\n",
    "            self.add_acc_scalars(step, acc, val_acc)\n",
    "        # learning rate is logged on its own whenever it was given\n",
    "        learning_rate = kwargs.pop(\"lr\", None)\n",
    "        if learning_rate is not None:\n",
    "            self.add_lr_scalars(step, learning_rate)\n"
   ],
   "id": "2647a0572138fc96",
   "outputs": [],
   "execution_count": 29
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T00:41:50.408568Z",
     "start_time": "2025-03-10T00:41:50.403840Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class SaveCheckpointsCallback:\n",
    "    def __init__(self, save_dir, save_step=5000, save_best_only=True):\n",
    "        \"\"\"\n",
    "        Save checkpoints each save_epoch epoch.\n",
    "        We save checkpoint by epoch in this implementation.\n",
    "        Usually, training scripts with pytorch evaluating model and save checkpoint by step.\n",
    "\n",
    "        Args:\n",
    "            save_dir (str): dir to save checkpoint\n",
    "            save_epoch (int, optional): the frequency to save checkpoint. Defaults to 1.\n",
    "            save_best_only (bool, optional): If True, only save the best model or save each model at every epoch.\n",
    "        \"\"\"\n",
    "        self.save_dir = save_dir\n",
    "        self.save_step = save_step\n",
    "        self.save_best_only = save_best_only\n",
    "        self.best_metrics = - np.inf\n",
    "\n",
    "        # mkdir\n",
    "        if not os.path.exists(self.save_dir):\n",
    "            os.mkdir(self.save_dir)\n",
    "\n",
    "    def __call__(self, step, state_dict, metric=None):\n",
    "        if step % self.save_step > 0:\n",
    "            return\n",
    "\n",
    "        if self.save_best_only:\n",
    "            assert metric is not None\n",
    "            if metric >= self.best_metrics:\n",
    "                # save checkpoints\n",
    "                torch.save(state_dict, os.path.join(self.save_dir, \"best.ckpt\"))\n",
    "                # update best metrics\n",
    "                self.best_metrics = metric\n",
    "        else:\n",
    "            torch.save(state_dict, os.path.join(self.save_dir, f\"{step}.ckpt\"))\n",
    "\n"
   ],
   "id": "2c6568a6e841055",
   "outputs": [],
   "execution_count": 30
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T00:41:54.571836Z",
     "start_time": "2025-03-10T00:41:54.566783Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class EarlyStopCallback:\n",
    "    def __init__(self, patience=5, min_delta=0.01):\n",
    "        \"\"\"\n",
    "\n",
    "        Args:\n",
    "            patience (int, optional): Number of epochs with no improvement after which training will be stopped.. Defaults to 5.\n",
    "            min_delta (float, optional): Minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute\n",
    "                change of less than min_delta, will count as no improvement. Defaults to 0.01.\n",
    "        \"\"\"\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.best_metric = - np.inf\n",
    "        self.counter = 0\n",
    "\n",
    "    def __call__(self, metric):\n",
    "        if metric >= self.best_metric + self.min_delta:\n",
    "            # update best metric\n",
    "            self.best_metric = metric\n",
    "            # reset counter\n",
    "            self.counter = 0\n",
    "        else:\n",
    "            self.counter += 1\n",
    "\n",
    "    @property\n",
    "    def early_stop(self):\n",
    "        return self.counter >= self.patience\n"
   ],
   "id": "8d7abd5eea9d0e7f",
   "outputs": [],
   "execution_count": 31
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 训练模型",
   "id": "e18e62ae569a2ec8"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "@torch.no_grad()\n",
    "def evaluating(model, dataloader, loss_fct):\n",
    "    loss_list = []\n",
    "    for batch in dataloader:\n",
    "        encoder_inputs = batch[\"encoder_inputs\"]\n",
    "        encoder_inputs_mask = batch[\"encoder_inputs_mask\"]\n",
    "        decoder_inputs = batch[\"decoder_inputs\"]\n",
    "        decoder_labels = batch[\"decoder_labels\"]\n",
    "        decoder_labels_mask = batch[\"decoder_labels_mask\"]\n",
    "\n",
    "        # 前向计算\n",
    "        logits, _ = model(\n",
    "            encoder_inputs=encoder_inputs,\n",
    "            decoder_inputs=decoder_inputs,\n",
    "            attn_mask=encoder_inputs_mask\n",
    "        )  #model就是seq2seq模型\n",
    "        loss = loss_fct(logits, decoder_labels, padding_mask=decoder_labels_mask)  # 验证集损失\n",
    "        loss_list.append(loss.cpu().item())\n",
    "\n",
    "    return np.mean(loss_list)\n"
   ],
   "id": "cbacf156815ad6e5",
   "outputs": [],
   "execution_count": 32
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T00:44:33.002654Z",
     "start_time": "2025-03-10T00:44:32.997827Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# training loop\n",
    "def training(\n",
    "        model,\n",
    "        train_loader,\n",
    "        val_loader,\n",
    "        epoch,\n",
    "        loss_fct,\n",
    "        optimizer,\n",
    "        tensorboard_callback=None,\n",
    "        save_ckpt_callback=None,\n",
    "        early_stop_callback=None,\n",
    "        eval_step=500,\n",
    "):\n",
    "    \"\"\"\n",
    "    Train `model` for `epoch` epochs, evaluating on `val_loader` every\n",
    "    `eval_step` optimisation steps. The optional callbacks handle\n",
    "    tensorboard logging, checkpoint saving and early stopping.\n",
    "\n",
    "    Returns:\n",
    "    - record_dict: {\"train\": [...], \"val\": [...]} with per-step losses.\n",
    "    \"\"\"\n",
    "    record_dict = {\n",
    "        \"train\": [],\n",
    "        \"val\": []\n",
    "    }\n",
    "\n",
    "    global_step = 1\n",
    "    # initialised up front so pbar.set_postfix at the end of an epoch cannot\n",
    "    # raise a NameError when train_loader is empty or when no evaluation has\n",
    "    # run yet (i.e. eval_step is larger than the steps taken so far)\n",
    "    loss = float(\"nan\")\n",
    "    val_loss = float(\"nan\")\n",
    "    model.train()  # switch to training mode\n",
    "    with tqdm(total=epoch * len(train_loader)) as pbar:\n",
    "        for epoch_id in range(epoch):\n",
    "            # training\n",
    "            for batch in train_loader:\n",
    "                encoder_inputs = batch[\"encoder_inputs\"]\n",
    "                encoder_inputs_mask = batch[\"encoder_inputs_mask\"]\n",
    "                decoder_inputs = batch[\"decoder_inputs\"]\n",
    "                decoder_labels = batch[\"decoder_labels\"]\n",
    "                decoder_labels_mask = batch[\"decoder_labels_mask\"]\n",
    "\n",
    "                # reset gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward pass\n",
    "                logits, _ = model(\n",
    "                    encoder_inputs=encoder_inputs,\n",
    "                    decoder_inputs=decoder_inputs,\n",
    "                    attn_mask=encoder_inputs_mask\n",
    "                )\n",
    "                loss = loss_fct(logits, decoder_labels, padding_mask=decoder_labels_mask)\n",
    "\n",
    "                # backpropagation\n",
    "                loss.backward()\n",
    "\n",
    "                # update parameters (and, if configured, the learning rate)\n",
    "                optimizer.step()\n",
    "\n",
    "                loss = loss.cpu().item()\n",
    "                # record the training loss of this step\n",
    "                record_dict[\"train\"].append({\n",
    "                    \"loss\": loss, \"step\": global_step\n",
    "                })\n",
    "\n",
    "                # evaluating\n",
    "                if global_step % eval_step == 0:\n",
    "                    model.eval()  # switch to evaluation mode\n",
    "                    val_loss = evaluating(model, val_loader, loss_fct)\n",
    "                    record_dict[\"val\"].append({\n",
    "                        \"loss\": val_loss, \"step\": global_step\n",
    "                    })\n",
    "                    model.train()  # back to training mode\n",
    "\n",
    "                    # 1. tensorboard visualisation\n",
    "                    if tensorboard_callback is not None:\n",
    "                        tensorboard_callback(\n",
    "                            global_step,\n",
    "                            loss=loss, val_loss=val_loss,\n",
    "                            lr=optimizer.param_groups[0][\"lr\"],\n",
    "                        )\n",
    "\n",
    "                    # 2. save model checkpoint (metric is the negated loss,\n",
    "                    #    so higher is better as the callbacks expect)\n",
    "                    if save_ckpt_callback is not None:\n",
    "                        save_ckpt_callback(global_step, model.state_dict(), metric=-val_loss)\n",
    "\n",
    "                    # 3. early stop\n",
    "                    if early_stop_callback is not None:\n",
    "                        early_stop_callback(-val_loss)\n",
    "                        if early_stop_callback.early_stop:\n",
    "                            print(f\"Early stop at epoch {epoch_id} / global_step {global_step}\")\n",
    "                            return record_dict\n",
    "\n",
    "                # update step\n",
    "                global_step += 1\n",
    "                pbar.update(1)\n",
    "            pbar.set_postfix({\"epoch\": epoch_id, \"loss\": loss, \"val_loss\": val_loss})  # refresh the progress bar\n",
    "\n",
    "    return record_dict"
   ],
   "id": "fb1dc986f4848e9",
   "outputs": [],
   "execution_count": 35
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T01:19:19.998223Z",
     "start_time": "2025-03-10T00:44:52.770244Z"
    }
   },
   "cell_type": "code",
   "source": [
    "epoch = 20\n",
    "batch_size = 64\n",
    "\n",
    "model = Sequence2Sequence(src_vocab_size=len(src_word2idx), trg_vocab_size=len(trg_word2idx))\n",
    "train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fct)\n",
    "test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fct)\n",
    "\n",
    "# 1. loss function: cross entropy that skips padded label positions\n",
    "loss_fct = cross_entropy_with_padding\n",
    "# 2. optimizer: Adam\n",
    "# Optimizers specified in the torch.optim package\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# 1. tensorboard visualisation\n",
    "if not os.path.exists(\"runs\"):\n",
    "    os.mkdir(\"runs\")\n",
    "exp_name = \"translate-seq2seq\"\n",
    "tensorboard_callback = TensorBoardCallback(f\"runs/{exp_name}\")\n",
    "# tensorboard_callback.draw_model(model, [1, MAX_LENGTH])\n",
    "# 2. save the best checkpoint every eval step\n",
    "if not os.path.exists(\"checkpoints\"):\n",
    "    os.makedirs(\"checkpoints\")\n",
    "save_ckpt_callback = SaveCheckpointsCallback(\n",
    "    f\"checkpoints/{exp_name}\", save_step=200, save_best_only=True)\n",
    "# 3. early stop\n",
    "early_stop_callback = EarlyStopCallback(patience=5)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "record = training(\n",
    "    model,\n",
    "    train_dl,\n",
    "    test_dl,\n",
    "    epoch,\n",
    "    loss_fct,\n",
    "    optimizer,\n",
    "    tensorboard_callback=None,  # NOTE(review): a TensorBoardCallback is built above but None is passed here — confirm logging is intentionally disabled\n",
    "    save_ckpt_callback=save_ckpt_callback,\n",
    "    early_stop_callback=early_stop_callback,\n",
    "    eval_step=200\n",
    ")"
   ],
   "id": "566862a44462482b",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "  0%|          | 0/33520 [00:00<?, ?it/s]"
      ],
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "0ef65c082a664c74ad2c30cd06d4aba0"
      }
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Program Files\\Python312\\Lib\\site-packages\\torch\\nn\\_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early stop at epoch 5 / global_step 9000\n"
     ]
    }
   ],
   "execution_count": 36
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T01:19:20.004967Z",
     "start_time": "2025-03-10T01:19:20.000760Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# count the total number of parameter elements in the model\n",
    "sum(p.numel() for _, p in model.named_parameters())"
   ],
   "id": "ad825758963c35f1",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "72968118"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 37
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T01:19:20.084452Z",
     "start_time": "2025-03-10T01:19:20.006008Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# training vs. validation loss curves\n",
    "plt.plot([i[\"step\"] for i in record[\"train\"]], [i[\"loss\"] for i in record[\"train\"]], label=\"train\")\n",
    "plt.plot([i[\"step\"] for i in record[\"val\"]], [i[\"loss\"] for i in record[\"val\"]], label=\"val\")\n",
    "plt.xlabel(\"step\")\n",
    "plt.ylabel(\"loss\")\n",
    "plt.legend()  # without this the label= arguments above are never displayed\n",
    "plt.grid()\n",
    "plt.show()"
   ],
   "id": "ab369844c369ef69",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAANyRJREFUeJzt3Qd4VGXa//F70hMgoRNK6FWQoiKCiCAdV0VY17YuuJZX14694Iqui+vu3xUVcd13hd1VbO+KrggoRUEUVJQiFjrSQUoSICSknP91P5MzmUkjCZNnJsz3c13DnJk5OXPmPCHnN087HsdxHAEAALAkytYbAQAAKMIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKtiJMwUFBTIrl27pE6dOuLxeEK9OwAAoAJ0ztLDhw9Ls2bNJCoqqmaFDw0eaWlpod4NAABQBdu3b5cWLVrUrPChNR7uzicnJwd127m5ufLRRx/JsGHDJDY2NqjbRuVRHuGF8ggvlEf4oUzKl5mZaSoP3PN4jQofblOLBo/qCB9JSUlmu/zihB7lEV4oj/BCeYQfyqRiKtJlgg6nAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgflfTWV9vl8037Q70bAADUWGF3Vdtwtnp7utz3nzVmeetTF4Z6dwAAqJGo+aiEHYeOhXoXAACo8QgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAACN/wMXnyZOndu7fUqVNHGjduLKNHj5Z169YFrDNw4EDxeDwBt5tuuinY+w0AACIhfCxevFhuueUWWb58ucyfP19yc3Nl2LBhcvTo0YD1brjhBtm9e7fv9vTTTwd7vwEAQA0VU5mV582bF/B4xowZpgbk66+/lgEDBvieT0pKktTU1ODtJQAAiMzwUVxGRoa5r1+/fsDzr732mrz66qsmgFx00UUyceJEE0hKk5OTY26uzMxMc6+1KnoLJnd7Vd1ufn5eiW0hdOWB4KI8wgvlEX4ok/JV5rh4HMdxpAoKCgrk4osvlvT0dFm6dKnv+ZdffllatWolzZo1kzVr1sj9998vZ599trzzzjulbuexxx6TSZMmlXh+5syZZQaWUFl5wCMz1keb5Sl9i4IIAACRLisrS6666ipTMZGcnFw94ePmm2+WuXPnmuDRokWLMtdbtGiRDB48WDZu3Cjt2rWrUM1HWlqa7N+//4Q7X5VUpn1Vhg4dKrGxsZX++blr98jtb64xyxueGBbUfYtEJ1seCC7KI7xQHuGHMimfnr8bNmxYofBRpWaXW2+9VWbPni1LliwpN3ioPn36mPuywkd8fLy5FacFW12FW9VtR0cXHS5+8YKnOssalUd5hBfKI/xQJqWrzDGpVPjQSpLbbrtNZs2aJZ988om0adPmhD+zatUqc9+0adPKvBUAADhFVSp86DBb7Yvx3nvvmbk+9uzZY55PSUmRxMRE2bRpk3l91KhR0qBBA9Pn46677jIjYbp3715dnwEAAJyq4WPatGm+icT8TZ8+XcaPHy9xcXGyYMECefbZZ83cH9p3Y+zYsfLII48Ed68BAECNVelml/Jo2NCJyAAAAMrCtV0AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4A
AIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAAAQvuFj8uTJ0rt3b6lTp440btxYRo8eLevWrQtYJzs7W2655RZp0KCB1K5dW8aOHSt79+4N9n4DAIBICB+LFy82wWL58uUyf/58yc3NlWHDhsnRo0d969x1113y/vvvy9tvv23W37Vrl4wZM6Y69h0AANRAMZVZed68eQGPZ8yYYWpAvv76axkwYIBkZGTIP/7xD5k5c6ZccMEFZp3p06dLly5dTGA555xzgrv3AADg1A4fxWnYUPXr1zf3GkK0NmTIkCG+dTp37iwtW7aUZcuWlRo+cnJyzM2VmZlp7nU7egsmd3tV3W5+fl6JbSF05YHgojzCC+URfiiT8lXmuFQ5fBQUFMidd94p5557rnTr1s08t2fPHomLi5O6desGrNukSRPzWln9SCZNmlTi+Y8++kiSkpKkOmiTUVWsPOARkWizPGfOnCDvVeSqanmgelAe4YXyCD+USemysrKk2sOH9v1Yu3atLF26VE7Ggw8+KBMmTAio+UhLSzN9SZKTkyXYqUx/aYYOHSqxsbGV/nnP2j0yY/0aszxq1Kig7lskOtnyQHBRHuGF8gg/lEn53JaLagsft956q8yePVuWLFkiLVq08D2fmpoqx48fl/T09IDaDx3toq+VJj4+3tyK04KtrsKt6rajo4sOF794wVOdZY3KozzCC+URfiiT0lXmmFRqtIvjOCZ4zJo1SxYtWiRt2rQJeP3MM880b75w4ULfczoUd9u2bdK3b9/KvBUAADhFxVS2qUVHsrz33ntmrg+3H0dKSookJiaa++uuu840o2gnVG02ue2220zwYKQLAACodPiYNm2auR84cGDA8zqcdvz48Wb5r3/9q0RFRZnJxXQUy/Dhw+XFF1/kaAMAgMqHD212OZGEhASZOnWquQEAABTHtV0AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4qASPXtQWAACcFMJHJVRgjjUAAHAChA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4qweMJ9R4AAFDzET4qwXFCvQcAANR8hA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfioBI8n1HsAAEAEho8lS5bIRRddJM2aNROPxyPvvvtuwOvjx483z/vfRowYEcx9BgAAkRQ+jh49Kj169JCpU6eWuY6Gjd27d/tur7/+upwKHCfUewAAQM0XU9kfGDlypLmVJz4+XlJTU09mvwAAwCmqWvp8fPLJJ9K4cWPp1KmT3HzzzXLgwIHqeBsAABAJNR8nok0uY8aMkTZt2simTZvkoYceMjUly5Ytk+jo6BLr5+TkmJsrMzPT3Ofm5ppbMLnbq+p28/PzSmwLoSsPBBflEV4oj/BDmZSv
MsfF4zhV78mgnUlnzZolo0ePLnOdzZs3S7t27WTBggUyePDgEq8/9thjMmnSpBLPz5w5U5KSkiScrDzgkRnrvQFqSt+iIAIAQKTLysqSq666SjIyMiQ5OdluzUdxbdu2lYYNG8rGjRtLDR8PPvigTJgwIaDmIy0tTYYNG3bCna9KKps/f74MHTpUYmNjK/3znrV7ZMb6NWZ51KhRQd23SHSy5YHgojzCC+URfiiT8rktFxVR7eFjx44dps9H06ZNy+ycqrfitGCrq3Cruu3o6KLDxS9e8FRnWaPyKI/wQnmEH8qkdJU5JpUOH0eOHDG1GK4tW7bIqlWrpH79+uamTShjx441o120z8d9990n7du3l+HDh1f2rQAAwCmo0uFjxYoVMmjQIN9jt8lk3LhxMm3aNFmzZo3885//lPT0dDMRmTafPPHEE6XWbgAAgMhT6fAxcOBAKa+P6ocffniy+wQAAE5hXNsFAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfioBI8n1HsAAEDNR/gAAABWET4qwXFCvQcAANR8hA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeGjEjyeUO8BAAA1H+EDAABYRfioBMcJ9R4AAFDzET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAEN7hY8mSJXLRRRdJs2bNxOPxyLvvvhvwuuM48uijj0rTpk0lMTFRhgwZIhs2bAjmPgMAgEgKH0ePHpUePXrI1KlTS3396aeflueee05eeukl+eKLL6RWrVoyfPhwyc7ODsb+AgCAGi6msj8wcuRIcyuN1no8++yz8sgjj8gll1xinvvXv/4lTZo0MTUkV1xxxcnvMQAAiKzwUZ4tW7bInj17TFOLKyUlRfr06SPLli0rNXzk5OSYmyszM9Pc5+bmmlswudur6nbz8/NKbAuhKw8EF+URXiiP8EOZlK8yxyWo4UODh9KaDn/62H2tuMmTJ8ukSZNKPP/RRx9JUlKSVIf58+dX6edWHfCISLRZnjNnTpD3KnJVtTxQPSiP8EJ5hB/KpHRZWVkSkvBRFQ8++KBMmDAhoOYjLS1Nhg0bJsnJyUFPZfpLM3ToUImNja30z3vW7pHp69eY5VGjRgV13yLRyZYHgovyCC+UR/ihTMrntlxYDx+pqanmfu/evWa0i0sf9+zZs9SfiY+PN7fitGCrq3Cruu3o6KLDxS9e8FRnWaPyKI/wQnmEH8qkdJU5JkGd56NNmzYmgCxcuDAgCemol759+wbzrQAAQA1V6ZqPI0eOyMaNGwM6ma5atUrq168vLVu2lDvvvFP+8Ic/SIcOHUwYmThxopkTZPTo0cHedwAAEAnhY8WKFTJo0CDfY7e/xrhx42TGjBly3333mblAbrzxRklPT5f+/fvLvHnzJCEhIbh7DgAAIiN8DBw40MznURad9fTxxx83NwAAgIi9tsuW/UflVy9/Ic9/FzEfGQCAsBTyoba2RHs8snJ7hsRF6VwdAAAgVCKmGiAhzvtRcwtCvScAAES2qEiq+VCOeMrtswIAAKpX5IQPv+aW/ALCBwAAoUL4AAAAVkVM+IgqbHZRZA8AAEInYsKHX/YQR0gfAACESkTWfNDfFACA0ImY8OGPZhcAAEInIptdtOEFAACERsSED5pdAAAIDxETPvwrPmh2AQAgdCKz5oNmFwAAQiYi+3xQ8wEAQOhEUPjw
n+iD9AEAQKhETPhQbv4gegAAEDqRFT5OstklcLguAACoiogKH26nU4dmFwAAQiaiwodbc0GHUwAAQieiwsfJosIEAICTF1Hhg2YXAABCL6LCB80uAACEXmSFj8J7ZjgFACB0IrTZJdR7AgBA5Iqo8OFWfRA+AAAIncis+aDZBQCAkInMGU4LQrwjAABEsMgKH1zbBQCAkIuo8ME8HwAAhF5EhQ8X2QMAgNCJ0GYX0gcAAKESkc0uzHAKAEDoROYMp4QPAABCJqLCB/N8AAAQehEVPpjhFACA0Iuo8EGzCwAAoRdR4YNmFwAAQi8ih9oy2gUAgNCJsPDBDKcAAIRaZIWPwnuyBwAAoRNZ4YMLywEAEHIRFT62HTxm7rfsP3pS4QUAAFRdRIUP18PvfR/qXQAAIGJFZPjIZ7gLAAAhE5Hho6roqAoAwMkjfAAAAKsIHwAAwCrCBwAAsCqiwkd8TER9XAAAwlJEnY1jopioAwCAUIuo8NEptU6odwEAgIgXUeFj0kVdfMtcXA4AgNCIqPDRvG6ib/mRd9fKseP5lfr5rQeqNi07AACI0PCRFBftW37ti23ywDtrKvXzf/5wXTXsFQAAkSWiwkd0sQ6n763aJXn5BbIr3XvBOQAAUP0iKnyUZuBfPpF+Ty2STzf8HOpdAQAgIkR8+NhxyFvrMWXBhlDvCgAAESHo4eOxxx4Tj8cTcOvcubOEu72Hs0O9CwAARISY6tho165dZcGCBUVvElMtb1MlE3vlyRMrS+7P9oP0+wAAwIZqSQUaNlJTUyUcNUwo+7WPf9wngzo3rtB2snPzJSG2aPQMAAAIYfjYsGGDNGvWTBISEqRv374yefJkadmyZanr5uTkmJsrMzPT3Ofm5ppbMLnbm35NT7n236tKvH7tjK/M/Vs3nC29WtY1E5F1fHS+eW7t74cErDvhzZUy5fIeQd2/SOOWR7DLGVVDeYQXyiP8UCblq8xx8ThBnupz7ty5cuTIEenUqZPs3r1bJk2aJDt37pS1a9dKnTp1Su0jousUN3PmTElKSpLqoJ/4zuXl5660Wo5sP1o0NLdbvQJZeyiwi8y93fOkRa1q2UUAAGqUrKwsueqqqyQjI0OSk5Ptho/i0tPTpVWrVvLMM8/IddddV6Gaj7S0NNm/f/8Jd74qqWz+/PkydOhQmTRnvby5YudJb3PJPQPkeF6BvLh4szw8spMkJ8aWut70z3+S/Udy5N5hHU/6PU8V/uURG1v6cYM9lEd4oTzCD2VSPj1/N2zYsELho9p7gtatW1c6duwoGzduLPX1+Ph4cytOC7a6Cle3+6df9pQ9mcdl8fqTm99jwF+W+JbfWblLnhpzujw461tTuzL3jvPk4NHjEuXxyB/nemdHjY+NkTsGd5CY6CjZsv+opNVLNMuRrDrLGpVHeYQXyiP8UCalq8wxqfbwoU0wmzZtkmuuuUbCzVNjT5e+kxcFdZsPvPOtb3nklE9LvP78oo0yd+0eOadtfXl1+TYZ06u5PHN5z6DuAwAA4SzoX7nvueceWbx4sWzdulU+//xzufTSSyU6OlquvPJKCTdNUxLly4cGyy+6N/V71pGhUSukiRystvfduO+ICR7qnZU7pfUDH8i/l20NWMe/NWzr/qNy68xv5Pp/fiUrtx0yI22Kyy9w5JWlW2Ttzoxq228AAIIh6DUfO3bsMEHjwIED0qhRI+nfv78sX77cLIejxskJ8sJVZ8jsNR+Yx2meffL3uGfM8i6nvqwqaC+rCtrJyoIO8q3TRrKlZBNRMEx87ztzU5ed2ULe/nqHWV5y7yAzBbxrwQ/7zP0Ht/eXrs1SfM+/880OeXz292Z561MXSkGBI7kFBRIfEzgcWIPL/366WYac1kQ6p3rb5HTdo8fzpE5CxavM/jp/vcxes0v+c3M/qZsUd1KfPdJp0Pxxz2Fp37i2xNbwJjjt19SgVpyZXBAArIWPN954Q2oiPWF/vG6fPDvjdfmhoKV09GyXZp6D
0iz6SxkV/aVZJ8+JknVOmgkkK532srKgvWx2mooT5AokN3ioAX/+uNR1LnxuqVzfv41cfU4rSYyNlm/9ajzufmu1fLPtkOlT0rB2nPzlsh7SuE6CTFu8yZzoZq/ZLX/5aL2M69tK7h/ZWe54Y5XM/36vLJgwQNo3LhqRtDvjmPx0IEuWrP9ZOjSpLS3qJUnv1vXNa1MWeqejf/jdtfLClb3Msv8JZ19mtsz5dreMObOFJFci1ESiGZ9vlUnvfy8ju6XKtF+fKTXV619ukwff+VZuHdRe7hneKdS7AyCMhc/Uo2HgjJb1ZLXTXkYef0qSJFu6R22WPnGbpUv+BukVtUGaeNKlq+cn6Rr1k1wtC83PZDnxssFpLusLWphgst5pYZb3iJ6kq/fb3/8u3WJuxf3nm6Lwsv/IcRk/3Tt/SXH/XPaTCS3fbEs3j4c8U9R5Vr+Fa/NQcbXioqVJctFMbR+s2W1uSvuxvHb9OebqwVf+fbls+vmo/Hv5TzK8a6r8dDDLXD349RvOMZOzvfHlNmlVv+SMb8eO50v6seOmSczf4exc03G3VnzZv7K5+QWSl+9IXkFBQC3O5p+PSFxMlAlPSgNYed/MtQlrx6EsadWgauOoDx09LvGxUZIUV7H/Xn9bvNnca1+g8mRk5cr7a3bJqNObSv1acQHNc/6fJ+t4niTEREtUsas4V7eJ76419y98vJHwAaBchA8/KYmx8v6t/c2Jo05CjPnG/ovnl8qU/UdNX5CmclB6Rm0svG2S7p7NkuTJkR6ezdIjynsCcWU6SbLOaSEb/ELJxoLm8rNoU0n4VEm7waO40oKHOno8Xzab41HS8s0HZdSUT+XJS7uZ4KH0/sVPNvnW6TxxXsDPTOkr8tXWQzJj2TZ5aFQXXxNT24a15O/jzpLGdeJl5bZ0+c0r3tqnpfcPkrdW7JDnFm4wZXV6ixRfYDjvTx/LnkzvNXp+eWYLU+OTmZ0rF/y/xea5p3/ZXV5estl8tilX9JRLejYv9XPc9eYq+e/qXfLE6G4yqluqNKhddlPbF5sPmBqlSRd3lZb1k+SWmd/InG+9IWLdH0bIJ+t+lr7tGpRZ+6M/7+7zidzx5kqzvfdW7ZS3b+pnmsvGvvS5qfl67fo+JoD8fDhHej+5QM5uU1/e+p++Egz6PhUJMidqaTlR6Nt2IMuUV7fmRc2JkUCPy2P//c6E4xsGtA317gBWED6KcU9mrtho94+lR3ZLA9ld0ECefnSiFDgiF72wRAoObpaOnh3SybNdOkTp/Q5p49ktyZ4s6e1ZL72j1gdsL8NJkg0mlDSXjU4z2Vi4vEsahFUoqap1ew/LL19aVuH1P9gWJR8t89bMfPT9Xt/zGnAGF4YGf/3/VNQMddELS8vc7v99vcOczL7cWtRx+L7/W+Nb1qYmvemJu03DWqa25O3/6Wf6vmjwcL/J623mDX2kb9sGpr+NdvzVmps/je0urRokyeUvLzfrXvzCZ+Y5N3ioTo/MC2jW0/4QOieM/vyhrOPy9U+HAvZJ5eTly2cb90ufNg1K1PJo8FAa1tTO9GMmmKns3AL5cU+mXPri5+bxl1sOltiuf/+fEwUB17c7MuTSFz+TvAJHGtaOl5d+fYapddl+6Jic37Hsflz7DmfL6u0ZMrhzY3lq3o8m9KUmJ5jh5/f/Z41ceXbLEpcycJsYX7iql1x4elOzf1pbNmvlTjm9eYp8tfWg3DKofcBlDUoLRlpmiXHRJsxWRF5+QYnh7lp7VJGaK//jqOWrNWzFg2bGsVxJiosutT+PNo2u33vY1EKqYIQP3Sctr9LeT8NpgeME1F6GkjbNLt2434T3mt7fqSxzv90t763aJU9f1p0maJuTjFVlkpKUlJQKTVJSlQli5syZI6NGjarweOTvdmWYZosRXVNNVfwDI7tIp9Q6Ac0E/1i62dQITCv8hh8nuSaAaBDpGLXd3Lf37JCWnn0S7Sn9cB914mWj01y2O41kv5MiPzt1Zb+k
yH4n2Tw2N0mRHKFz56lMT1JZx/PlvA4NZeIvTpM/zf1R+rStb2p7/Guj/vnbs83jJwo7GX//+HA57dEPA7algSc967jpyPx+YaBSWkOz7WCWWR7etYkJE69/uV2uOaelbN6yVT7bG2U6PWvT2+S5P5a5r7qPL19zljlRa/Pe8GeLmu0qY9hpTQKCp6tFvUTThKX/t/xd1aelDOrUWKYsXC9rd2ZKfEyU3DigrYw5o4U8+cH3vk7Z/do1MB2su7eoK49d3NV3/aZFP+6T6/q3kdSUBJn68Ub5x9It8t9b+5vPq/RY3fb6SnloVGcZ3au5qRHV4KZBR5spuzRNNiFjwfd75fp/rZCr+7SUu4Z2lLP+sMB33LUGJ9rjkcPZeXLO5IVm2wsmnG+CTk5egSnjZ+avM8fdn5aHhljtf9avTT1ZvniBbK/dRRolJ8jlvVvKhLdWyaZ9R0xH77LmB9Jawh92Z8rieweaAKWhTWsQtWal3UNzzDo/PD5CYqI95oSvn0vz06GsXBPGNbjp71ZCbFFTpfu3UEfT/eqsNBO4KhJg9e+jbq8sOtJPTR5zugmkuk39YqdNt+FIzyHjnp8nn++Lkn9fd7ac16GRCfZ6YVL396esz6i/oxUNxDVVZc7fhI8KqOi3RF1PmxhOa5Ysi37YJ02S4+XC7s0kOSFG1u7KlBtfWSptPHvklQvryNvzFppA0t6zywSVOE/J4bOlyXQSTTA5IMly0EmWA3qTOn7LRc8flDqSL1z8Dph5fR/T/NXmQe/JtzQaGl774id5eJa374pL/x8vvf8C6froh3I8v+CE76UhY8gzJWvt/GmtmXbmPpHSLuvgfp75P+w1oe+5K3rK7W+sMqFp1u/6+Wq/Hrmwi6kBeaowQOrjP3zwwwmD37OX95Q731zlCwVvfrVd+rSpL39b4m1a/ts1Z8qR7Dy5++3Vcv+IznJpr+amc7oGEw26j110mow/t43pxH7Dv1aY92iakiC/6dda2jWqLTM+22Jqq645p5WvPO4a0lFaN0wytZEaOj++Z2CJmhCtNbzq78tNraGGwT9f1j1gxJ/S2s7YGI80qBVvAqKG16zcfDMCy60xO5qT56tVnP7ZFtPZ+9pzW8uYXi3kv6t3ypmt6smejGwT9ooHp1U/HZDR07y1nUpr9hb+6A27//ubs+S7XZny1wXrfWXUr31DX/gYeloT+evlPaV2Gf3W3lqxXZZu2C/j+rUyNbf6RVeDb1XMW7tH/rVsqzzzq55mosvGyfGm5tLtA6faNio9LJ0MwkeQw0ew6LdQrZrVESX6n8StZo6RPGnl2SsLf9NUFiz/Rn7ctEkaSYb0rH9cstP3SENPhjSUDIn35FX6PY84CXJM4k3H2CxJkKzCZfOcb9n7/DGzHGfWK1p2n/c+Pu7ESr5ESYFEmfuylgtME1J4fnsBKuvc9g3ks40HJFxoM5T/CLdws+ju8319rSpCL+TpNiG6tQTaVDdhaEdTU6Un4d6t65lResW987t+0iutruw4dEzOe7r00YEno1OTOub9s3LzTE3byXr1uj6y/VCWmapg7BktTGhcv/eIaUYtbt6d58mBI97zhjZFam3X3sxseeA/a2Rkt6YysFMjOZh13IQNrTHS46XcwKNBTo+Le5xOa5rs63f34xMjgn5ldsJHmIaP4j5Zt0/+599fm28AWr3rdrRbvvmAWdb0rd+gNMGvf2KESHaGXPnX96Xg8F5p4MmQ+p7D0kAypYEnU+p7MqWh3kumpMYckdr5mRJVRhOPDblOtAkuRzXYOAl+994QdNTRwKPBKE5ynRjJlRg5Lt577y1ajhc+730tVnL05sRKtsSZ5WzHe2+WJU7yTC0PgQfAqWfsGS3k3VU7Tef6qtKO9lq7pL56eIg0qhPceasIHzUkfFSEtr3qXB3uRF7avq6dGrXYnl2wwVSnzbz+HHnh4w3y4+7D8vJvzjLtpW0feF/qyhFJ9hyVJMmRRMkxI3P8lxMl2zz2
LuvtuCR5sv2W3eeL1ok1p/gCifGcuPrZtnzHI3kSY2pdHHPTMUpFy+L3nH4KE2IKw4y5OYX3hWEmpzDc6PpubY7O6ZIvHlPD475PvuOt7XGDkm7bBCnHb7nweb3XnylwvNt0/LblvXlrj3Q5sCYpyrxPXrHnvDVMXv6xy/3ELu9n1p/3BjtCGhDZpl19how83X92b7vnb0a7hLninZi0A5kOI1WXnZXme/7e4Z0D1tMT08HC/h+uYV2ayIeF7bva5vzB7eeZ9s+bXvXWvuhohFe/+ElSEuOkdYMkc9G9Ry86Tc5/1nuNmtsuaG8uinfff9bIO9/sNKfK6MJblN+9nlLjJVd+eXpdyT6aKWu37DahprYcMyGmlgk92VLLky0Jpk4jT7QxJ9aTZ+ovYv1ucR7vvW7Pd/Pkmp9zl13amVdP75USoedgnTDPDUPe4BRtgkmeEx3QhFY87Oi9/pyGmdJCTmnP5Tv6PtG+n9V+SLpd73104XsWBS/9ae+9fyDzlP7YKQpuTkCI8z7233//EOf/efSdiqKld1lrDfWx93fau+zbRrGfL36M/IOvKh6EXb7XHf/nir1W7rHw/+z+N+/Pu8+bEnGK3j8wnLvHu/g2itYrKOe5kv+B/GN+4Kvu+yE83PzaN6afU6hQ83GKevS9tfKvwuF72gHsirNbmp7n2sFKRwVo+DjZKbB1SvcJb602HWozs0v2R/H/xdZfM21C0nk/ujZLNh2zyqKjGdw27ZlfeK+Bc/fQjmYYr87O6k//DGpgiS8MI7rs/VhFf+ob146TuGiP6RTn/hnVU64GmATJlQSP3uvPHzfLuh3va8dN+PE/pbgnI7fewbusz2t4yveFJt2+G6RiPH7LhfUd7inEd3LzeJ/z375/qAtYDmFzGhCMGkr/Gj63ls8NVKoo7riKlt1g4x+2/ENZ8ZBUGL8CApF/PHSf9w/X3lpG7/+6oqBc9JfAX2kB3OUfps1jp/SwVxEVXU+dKDjqXu9x6sn1fwjujOQ0u5QhksKHFmt6Vm6lZtqs6vtoiNEOUdpDW8ex63Ts2odlRLfUEuvqb5s7L8PiH/fIw2+vEIlNlB3p2fKH0d1Mz+zf9m/j6xH+5lfbzNDE3/RtbR7r0Emdp2Dr/izTE31Q4aRkWyaPkmWbDkha/SSpmxQrf5zzgxnGqMFozWPDzXbHvPiZmWlV55rQ69roHBI6qZn2ondDj05Zr/MyaE/+zYUTpan5dw0w095rJzh/o05PNc1gvz23jRm6qEMa3REF06/tLa3qJwV0vIuLjjIdvdoWDnl06VDued+VP8Op35EsDDzeP7ml/4ku+hatfAFJ8k0A0psbjLzLeu+tgSoKOUWBp3HtWDl05JhfEHJKfEv3X3Zf00e6fX1vbRBz38/dF9/zki/1kmIk89hx3+fz/1PvPS15P7P75zRwncLjYWosivbRW89SdKwCQ5x3W24zWPFmL//aFF2n9CBYUPgeRc/5n9CKn+iklBOhK7BxMHBdN+S6IdX3uMRxKSoF9+cCT3FF2zPHjCAbsTYWNJP2jweOfjpZhI8yRFL4qAnc8hgxYqQczvVOYlVZn2/aL41qx0uHJkVzr7i0w25yYky54UuH4s1Zu1t+0b1ZiWFtOqW7zto6oGND3wRdOqZfFb9gnz/9L6XzOLjD+XSiLu3drvMmnN+xsemTo1PBb/r5iHy15aCZc0PH/+fmO6K5TPuTZZtJwaJ8E5Vd2L2pbxr7suhQSh1W2TOtrqz4yTsRmT8dIqnNeO7kbR0a1zYz+bqz3F54eqrs37NLnFoNzBDBm85vJzFRHjOtvQZGDX7FZ6h1ndGybsBsuRoG/Ye1Pn5JV5m9enfApG/+Pr1vkDSvm1gilFUnvZbOiaa0D6bLz0qTjql1fHOzFPfMr3qYmsSy9GiRIqt3FI1w+dVZLcz8LyevKJy4vYj0Xoo9dkNg4HpFIdTdkve+5HPFA5MJTSYAFj1239d/r/xDrv/23f3wD1j+
23EfF6/jKN48VbR/GoK9IdIXmP3CpenrJvkVrpUoXrdR2rGMqkCzpfe5wGNaHrPtwmbDgOBa7DjVa9BI7pwwUYKJ8FEGwkd4oTxOTPvk6H9QrQl6aNa38tOBo6b25Ko+rWRIl8byz8+3yp/mrZNXr+9j5idwfb8r00zZrcHJHZ7oNoN9uuFnWbfnsJloS2utDhzJMUP4kmKkQuWhtUga+nTeAp1nQof+6XWRtMZo1jc7TBA8t31Ds64GFv/hfG6NnA6f1NolrTl68eqii+lt3X/UzKh644B28qd5pU9wNqZXc7nuvDamQn3Uc97+SK7x/VqbeSl0xsxPN+43wxj1mkNv3NjXTLylk2jpfBXaf6lZ3URzPM//c9FVozUAay2edvLWQKrNkzr9/e0XdDCzkep8FWcWTiamnhpzuqmB0zkp9I++Xln6+St7mVq3g0dz5JXxvX0hbM7t55k5gNo/NMfMv1Gclo/WwmkNmM6oWztWZMRf5stPRzy+ywGMeHaJuQLyP8adJYO7NCkzEGr407J1h1xqsH70F6fJ6h3pZpI0nahMm2KVzsexdOPPZijxjGt7B1wL6utHhphjvDczxzzWSyfo79+9xWbmHXNGczOnhE6ipp/Nf5/aNqoVUIs4+7b+ZqbgrRWY56SyHhjZWbo1S5F7/2+1/G5Qe9/1hh4c2VkGdGwkv3ppmbRqmGQuzaDNuHoNJC3j8uZ/cYP6hjIuOfHSr880AwFeWLQx4NITOkRX30PViY+RP1/Ww/SvCxddmiabWuBgInyUgZNdeKE8gkOH3pU1I6T+977n7TXmqsRakxFO5aGTHelsq2XN1Lnwh71mmvhbLmhv5jEY2qWJmTk0rV6Sr+lOm860b5DOJqr3dw7p4JshUz+71irpe5TXv6nb7z+UIzl5vvkidKbPWvHRZdZuaa2VnvQ1yJS17/50YiwNNBo+dH2tPcs8lieXv7xM2jasLRf1aCodm9QxJ4PSyuO8C4ZK/TremUa135Z+Jv9Zlq+d/qX8fCRH7hjc0Wx7VLemvuMzb+1umbZ4s0y5vKe0blixCyW6ITI1JdE0neqJ9Zuf0k2Icz+vhjYNQXr20BBa/Pdv9NTPZNX2dDM/il5s0g1B2jypndi1bHRuC720wYqtB2X5loNmMrKV2w6ZE/j0z7aWCISDOjWSIac1kaycfGmSkiC3v75SPri9v7RuUMvUMmpw02kJ/OnxUu5kYfr7UysupsT+uvv39k19zZW7dVsdH5nre12bSnWCuUunfua7bpXSmYA1VCj9mRFTlpgqioV3n+/7nfP//+kfFt/6n76m/9sz89ebWlb/Wiy94vikS7rJrJU75K43V0vn1DrmgpLatHvjeW3NSMczWtWV385YIc1SEmRXRtH1oTRkSmEQdr0y/izzO6YhWa9d9fmmA+b6WYvuGSjBRPgoAye78EJ5hJdILQ892Wrtjf5xrg7F+zq53GnNywpGFS2P0q5sHGp6DRm9AKJ2btfp+7UJVGtdNEDavtpyRWzcd9iECr0Ct0tnZ9VaIe38/v5t/c1zR4/lyDuz58nPKZ1ky4FjJtT5f54TlWlZ5ZWTly8XP/+ZpNVPNH3PdNp2DQq6roY8rT0qr6lXf4d/848vzDWedGIy3bZetHLKwg2myVMntnRpx/u3V+wwoU7fK5gYagsAFaRNElWdxroi9ERQ2rkoWCfhcAodLp286vrzii6Sp9fRSU0J7IAeTvTk7H+CVnpVbA1QWuPgMhcOjBO5YlC7UgNhxa7+XHIdDRZuaCi+bvEasdLo7+97t3oDkqtP2wYys61esDRQ05REuX1wBwk1wgcAAKWc0N1RdjZ4wjBEVidmfAEAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAJF9VVvHccx9ZmZm0Ledm5srWVlZZtulXQ4ZdlEe4YXyCC+UR/ihTMrn
nrfd83iNCh+HDx8292lpaaHeFQAAUIXzeEpKSrnreJyKRBSLCgoKZNeuXVKnTh3xeDxBT2UaarZv3y7JyclB3TYqj/IIL5RHeKE8wg9lUj6NExo8mjVrJlFRUTWr5kN3uEWLFtX6HvpLwy9O+KA8wgvlEV4oj/BDmZTtRDUeLjqcAgAAqwgfAADAqogKH/Hx8fL73//e3CP0KI/wQnmEF8oj/FAmwRN2HU4BAMCpLaJqPgAAQOgRPgAAgFWEDwAAYBXhAwAAWBUx4WPq1KnSunVrSUhIkD59+siXX34Z6l06JUyePFl69+5tZqRt3LixjB49WtatWxewTnZ2ttxyyy3SoEEDqV27towdO1b27t0bsM62bdvkwgsvlKSkJLOde++9V/Ly8gLW+eSTT+SMM84wPc3bt28vM2bMsPIZa6qnnnrKzBJ85513+p6jLOzbuXOn/PrXvzbHPDExUU4//XRZsWKF73Xt8//oo49K06ZNzetDhgyRDRs2BGzj4MGDcvXVV5uJrerWrSvXXXedHDlyJGCdNWvWyHnnnWf+xuksnE8//bS1z1hT5Ofny8SJE6VNmzbmWLdr106eeOKJgGuRUB6WOBHgjTfecOLi4pxXXnnF+e6775wbbrjBqVu3rrN3795Q71qNN3z4cGf69OnO2rVrnVWrVjmjRo1yWrZs6Rw5csS3zk033eSkpaU5CxcudFasWOGcc845Tr9+/Xyv5+XlOd26dXOGDBnirFy50pkzZ47TsGFD58EHH/Sts3nzZicpKcmZMGGC8/333zvPP/+8Ex0d7cybN8/6Z64JvvzyS6d169ZO9+7dnTvuuMP3PGVh18GDB51WrVo548ePd7744gtz7D788ENn48aNvnWeeuopJyUlxXn33Xed1atXOxdffLHTpk0b59ixY751RowY4fTo0cNZvny58+mnnzrt27d3rrzySt/rGRkZTpMmTZyrr77a/F98/fXXncTEROdvf/ub9c8czp588kmnQYMGzuzZs50tW7Y4b7/9tlO7dm1nypQpvnUoDzsiInycffbZzi233OJ7nJ+f7zRr1syZPHlySPfrVLRv3z79CuEsXrzYPE5PT3diY2PNf3LXDz/8YNZZtmyZeawnuKioKGfPnj2+daZNm+YkJyc7OTk55vF9993ndO3aNeC9Lr/8chN+EOjw4cNOhw4dnPnz5zvnn3++L3xQFvbdf//9Tv/+/ct8vaCgwElNTXX+/Oc/+57TcoqPjzcnLKUBT8voq6++8q0zd+5cx+PxODt37jSPX3zxRadevXq+MnLfu1OnTtX0yWqmCy+80Pntb38b8NyYMWNMSFCUhz2nfLPL8ePH5euvvzZVZ/7Xj9HHy5YtC+m+nYoyMjLMff369c29Hnu9DLX/8e/cubO0bNnSd/z1XquimzRp4ltn+PDh5iJO3333nW8d/22461CGJWmzijabFD9elIV9//3vf+Wss86Syy67zDRh9erVS/7+97/7Xt+yZYvs2bMn4HjqtTG0adi/TLRqX7fj0vX179gXX3zhW2fAgAESFxcXUCbaBHro0CFLnzb89evXTxYuXCjr1683j1evXi1Lly6VkSNHmseUhz1hd2G5YNu/f79p5/P/Y6r08Y8//hiy/ToV6RWJtX/BueeeK926dTPP6X9k/Q+o/1mLH399zV2ntPJxXytvHT0pHjt2zLTNQuSNN96Qb775Rr766qsSr1EW9m3evFmmTZsmEyZMkIceesiUy+23327KYdy4cb5jWtrx9D/eGlz8xcTEmIDvv472Yyi+Dfe1evXqVevnrCkeeOAB83uqoTs6OtqcG5588knTf0NRHvac8uEDdr9xr1271nyTgH16me877rhD5s+fbzq5ITwCuX5D/uMf/2gea82H/h956aWXTPiAXW+99Za89tprMnPmTOnatausWrXKfGHSS8BTHnad8s0uDRs2NAm3eI9+fZyamhqy/TrV3HrrrTJ79mz5+OOPpUWLFr7n
9Rhr01d6enqZx1/vSysf97Xy1tHe5nzTLmpW2bdvnxmFot/E9LZ48WJ57rnnzLJ+86Is7NIRE6eddlrAc126dDEjivyPaXl/n/Rey9Wfjj7SEReVKTeIGbmltR9XXHGFaV685ppr5K677jKj9hTlYc8pHz60evPMM8807Xz+30b0cd++fUO6b6cC7bSswWPWrFmyaNGiElWNeuxjY2MDjr+2e+ofX/f46/23334b8B9av73rycz9w63r+G/DXYcyLDJ48GBzHPXbnHvTb91apewuUxZ2aRNk8aHn2t+gVatWZln/v+jJyP94arOA9h3wLxMNjBouXfp/Tf+OaV8Ed50lS5aYPj3+ZdKpUyeq+P1kZWWZvhn+9MupHktFeVjkRMhQW+2tPGPGDNNT+cYbbzRDbf179KNqbr75ZjMs7ZNPPnF2797tu2VlZQUM79Tht4sWLTLDO/v27WtuxYd3Dhs2zAzX1SGbjRo1KnV457333mtGaEydOpXhnRXgP9pFURb2hzzHxMSYIZ4bNmxwXnvtNXPsXn311YChnfr36L333nPWrFnjXHLJJaUO7ezVq5cZrrt06VIzmsl/aKeOyNChnddcc40Z2ql/8/R9GNoZaNy4cU7z5s19Q23feecdM5RcR3C5KA87IiJ8KJ2LQP/o6nwfOvRWx2fj5Gl+Le2mc3+49D/t7373OzP0TP8DXnrppSag+Nu6daszcuRIMxZe/xjcfffdTm5ubsA6H3/8sdOzZ09Thm3btg14D1QsfFAW9r3//vsm0OkXoM6dOzsvv/xywOs6vHPixInmZKXrDB482Fm3bl3AOgcOHDAnN52TQoc9X3vttWZItT+dk0KH9eo29ASrJ1EEyszMNP8f9FyQkJBgfncffvjhgCGxlIcdHv3HZk0LAACIbKd8nw8AABBeCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAADEpv8P8hT0yNnu/LMAAAAASUVORK5CYII="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 38
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 推理",
   "id": "bb9b74d3121fcad"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-03-10T02:33:29.972343Z",
     "start_time": "2025-03-10T02:33:29.681593Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 模型上线\n",
    "model = Sequence2Sequence(len(src_word2idx), len(trg_word2idx))\n",
    "model.load_state_dict(torch.load(f\"./checkpoints/translate-seq2seq/best.ckpt\", map_location=\"cpu\"))\n",
    "\n",
    "\n",
    "class Translator:\n",
    "    def __init__(self, model, src_tokenizer, trg_tokenizer):\n",
    "        self.model = model\n",
    "        self.model.eval()  # 切换到验证模式\n",
    "        self.src_tokenizer = src_tokenizer\n",
    "        self.trg_tokenizer = trg_tokenizer\n",
    "\n",
    "    def draw_attention_map(self, scores, src_words_list, trg_words_list):\n",
    "        \"\"\"\n",
    "        绘制注意力热力图\n",
    "        Args:\n",
    "            - scores (numpy.ndarray): shape = [source sequence length, target sequence length]\n",
    "        \"\"\"\n",
    "        plt.matshow(scores.T, cmap='viridis')  # 注意力矩阵,显示注意力分数值\n",
    "        # 获取当前的轴\n",
    "        ax = plt.gca()\n",
    "        # 设置热图中每个单元格的分数的文本\n",
    "        for i in range(scores.shape[0]):  #输入\n",
    "            for j in range(scores.shape[1]):  #输出\n",
    "                ax.text(j, i, f'{scores[i, j]:.2f}',  # 格式化数字显示\n",
    "                        ha='center', va='center', color='k')\n",
    "        plt.xticks(range(scores.shape[0]), src_words_list)\n",
    "        plt.yticks(range(scores.shape[1]), trg_words_list)\n",
    "        plt.show()\n",
    "\n",
    "    def __call__(self, sentence):\n",
    "        sentence = preprocess_sentence(sentence)  # 预处理句子，标点符号处理等\n",
    "        encoder_input, attn_mask = self.src_tokenizer.encode(\n",
    "            [sentence.split()],\n",
    "            padding_first=True,\n",
    "            add_bos=True,\n",
    "            add_eos=True,\n",
    "            return_mask=True,\n",
    "        )  # 对输入进行编码，并返回encode_piadding_mask\n",
    "        encoder_input = torch.Tensor(encoder_input).to(dtype=torch.int64)  # 转换成tensor\n",
    "        preds, scores = model.infer(encoder_input=encoder_input, attn_mask=attn_mask)  #预测\n",
    "        trg_sentence = self.trg_tokenizer.decode([preds], split=True, remove_eos=False)[0]  #通过tokenizer转换成文字\n",
    "        src_decoded = self.src_tokenizer.decode(\n",
    "            encoder_input.tolist(),\n",
    "            split=True,\n",
    "            remove_bos=False,\n",
    "            remove_eos=False\n",
    "        )[0]  #对输入编码id进行解码，转换成文字,为了画图\n",
    "        self.draw_attention_map(\n",
    "            scores.squeeze(0).numpy(),\n",
    "            src_decoded,  # 注意力图的源句子\n",
    "            trg_sentence  # 注意力图的目标句子\n",
    "        )\n",
    "        return \" \".join(trg_sentence[:-1])"
   ],
   "id": "28007a3a21f53b8f",
   "outputs": [],
   "execution_count": 53
  },
  {
   "metadata": {
    "jupyter": {
     "is_executing": true
    },
    "ExecuteTime": {
     "start_time": "2025-03-10T02:40:09.441636Z"
    }
   },
   "cell_type": "code",
   "source": [
    "model = Sequence2Sequence(len(src_word2idx), len(trg_word2idx))\n",
    "model.load_state_dict(torch.load(f\"./checkpoints/translate-seq2seq/best.ckpt\", map_location=\"cpu\"))\n",
    "\n",
    "\n",
    "class Translator:\n",
    "    def __init__(self, model, src_tokenizer, trg_tokenizer):\n",
    "        self.model = model\n",
    "        self.model.eval()  # 切换到验证模式\n",
    "        self.src_tokenizer = src_tokenizer\n",
    "        self.trg_tokenizer = trg_tokenizer\n",
    "\n",
    "    def __call__(self, sentence):\n",
    "        sentence = preprocess_sentence(sentence)  # 预处理句子，标点符号处理等\n",
    "        encoder_input, attn_mask = self.src_tokenizer.encode(\n",
    "            [sentence.split()],\n",
    "            padding_first=True,\n",
    "            add_bos=True,\n",
    "            add_eos=True,\n",
    "            return_mask=True,\n",
    "        )  # 对输入进行编码，并返回encode_piadding_mask\n",
    "        encoder_input = torch.Tensor(encoder_input).to(dtype=torch.int64)  # 转换成tensor\n",
    "\n",
    "        preds, scores = model.infer(encoder_input=encoder_input, attn_mask=attn_mask)  #预测\n",
    "\n",
    "        trg_sentence = self.trg_tokenizer.decode([preds], split=True, remove_eos=False)[0]  #通过tokenizer转换成文字\n",
    "\n",
    "        return \" \".join(trg_sentence[:-1])\n",
    "\n",
    "\n",
    "from nltk.translate.bleu_score import sentence_bleu\n",
    "\n",
    "\n",
    "def evaluate_bleu_on_test_set(test_data, translator):\n",
    "    \"\"\"\n",
    "    在测试集上计算平均 BLEU 分数。\n",
    "    :param test_data: 测试集数据，格式为 [(src_sentence, [ref_translation1, ref_translation2, ...]), ...]\n",
    "    :param translator: 翻译器对象（Translator 类的实例）\n",
    "    :return: 平均 BLEU 分数\n",
    "    \"\"\"\n",
    "    total_bleu = 0.0\n",
    "    num_samples = len(test_data)\n",
    "    i = 0\n",
    "    for src_sentence, ref_translations in test_data:\n",
    "        # 使用翻译器生成翻译结果\n",
    "        candidate_translation = translator(src_sentence)\n",
    "\n",
    "        # 计算 BLEU 分数\n",
    "        bleu_score = sentence_bleu([ref_translations.split()], candidate_translation.split(), weights=(1, 0, 0, 0))\n",
    "        total_bleu += bleu_score\n",
    "\n",
    "        # 打印当前句子的 BLEU 分数（可选）\n",
    "        # print(f\"Source: {src_sentence}\")\n",
    "        # print(f\"Reference: {ref_translations}\")\n",
    "        # print(f\"Candidate: {candidate_translation}\")\n",
    "        # print(f\"BLEU: {bleu_score:.4f}\")\n",
    "        # print(\"-\" * 50)\n",
    "        # i+=1\n",
    "        # if i>10:\n",
    "        #     break\n",
    "    # 计算平均 BLEU 分数\n",
    "    avg_bleu = total_bleu / num_samples\n",
    "    return avg_bleu\n",
    "\n",
    "\n",
    "translator = Translator(model.cpu(), src_tokenizer, trg_tokenizer)\n",
    "evaluate_bleu_on_test_set(test_ds, translator)"
   ],
   "id": "8e0e659793c582d8",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\35493\\AppData\\Roaming\\Python\\Python312\\site-packages\\nltk\\translate\\bleu_score.py:577: UserWarning: \n",
      "The hypothesis contains 0 counts of 2-gram overlaps.\n",
      "Therefore the BLEU score evaluates to 0, independently of\n",
      "how many N-gram overlaps of lower order it contains.\n",
      "Consider using lower n-gram order or use SmoothingFunction()\n",
      "  warnings.warn(_msg)\n",
      "C:\\Users\\35493\\AppData\\Roaming\\Python\\Python312\\site-packages\\nltk\\translate\\bleu_score.py:577: UserWarning: \n",
      "The hypothesis contains 0 counts of 3-gram overlaps.\n",
      "Therefore the BLEU score evaluates to 0, independently of\n",
      "how many N-gram overlaps of lower order it contains.\n",
      "Consider using lower n-gram order or use SmoothingFunction()\n",
      "  warnings.warn(_msg)\n",
      "C:\\Users\\35493\\AppData\\Roaming\\Python\\Python312\\site-packages\\nltk\\translate\\bleu_score.py:577: UserWarning: \n",
      "The hypothesis contains 0 counts of 4-gram overlaps.\n",
      "Therefore the BLEU score evaluates to 0, independently of\n",
      "how many N-gram overlaps of lower order it contains.\n",
      "Consider using lower n-gram order or use SmoothingFunction()\n",
      "  warnings.warn(_msg)\n"
     ]
    }
   ],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "dfc6518375badd8f"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
